From 0c7bbc4f0045a15f20b1137fe54c640ffebcdae6 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 15 Nov 2023 23:46:27 +0100 Subject: [PATCH 01/73] finished reworking initial components --- transformer_lens/components.py | 1306 ----------------- transformer_lens/components/BertEmbed.py | 62 + transformer_lens/components/__init__.py | 15 + transformer_lens/components/attention.py | 566 +++++++ transformer_lens/components/bert_block.py | 89 ++ transformer_lens/components/bert_embed.py | 62 + transformer_lens/components/bert_mlm_head.py | 42 + transformer_lens/components/embed.py | 35 + transformer_lens/components/gated_mlp.py | 110 ++ transformer_lens/components/layer_norm.py | 57 + transformer_lens/components/layer_norm_pre.py | 55 + transformer_lens/components/mlp.py | 81 + transformer_lens/components/pos_embed.py | 73 + transformer_lens/components/rms_norm.py | 52 + transformer_lens/components/rms_norm_pre.py | 39 + .../components/token_typed_embed.py | 32 + .../components/transformer_block.py | 189 +++ transformer_lens/components/unembed.py | 39 + 18 files changed, 1598 insertions(+), 1306 deletions(-) delete mode 100644 transformer_lens/components.py create mode 100644 transformer_lens/components/BertEmbed.py create mode 100644 transformer_lens/components/__init__.py create mode 100644 transformer_lens/components/attention.py create mode 100644 transformer_lens/components/bert_block.py create mode 100644 transformer_lens/components/bert_embed.py create mode 100644 transformer_lens/components/bert_mlm_head.py create mode 100644 transformer_lens/components/embed.py create mode 100644 transformer_lens/components/gated_mlp.py create mode 100644 transformer_lens/components/layer_norm.py create mode 100644 transformer_lens/components/layer_norm_pre.py create mode 100644 transformer_lens/components/mlp.py create mode 100644 transformer_lens/components/pos_embed.py create mode 100644 transformer_lens/components/rms_norm.py create mode 100644 
transformer_lens/components/rms_norm_pre.py create mode 100644 transformer_lens/components/token_typed_embed.py create mode 100644 transformer_lens/components/transformer_block.py create mode 100644 transformer_lens/components/unembed.py diff --git a/transformer_lens/components.py b/transformer_lens/components.py deleted file mode 100644 index 9c8d663cd..000000000 --- a/transformer_lens/components.py +++ /dev/null @@ -1,1306 +0,0 @@ -"""Hooked Transformer Components. - -This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. -""" -import logging -from typing import Dict, Optional, Tuple, Union - -import einops -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from fancy_einsum import einsum -from jaxtyping import Float, Int - -from transformer_lens.FactoredMatrix import FactoredMatrix -from transformer_lens.hook_points import HookPoint -from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry -from transformer_lens.utils import gelu_fast, gelu_new, get_offset_position_ids, solu - - -# Embed & Unembed -class Embed(nn.Module): - def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): - super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg - self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter( - torch.empty(self.cfg.d_vocab, self.cfg.d_model, dtype=cfg.dtype) - ) - # Some models (e.g. 
Bloom) need post embedding layer norm - if cfg.post_embedding_ln: - self.ln = LayerNorm(cfg) - - def forward( - self, tokens: Int[torch.Tensor, "batch pos"] - ) -> Float[torch.Tensor, "batch pos d_model"]: - # If A has shape [a, b] and B has shape [c, d], then A[:, B] has shape [a, c, d] - # B acts as a tensor of indices into the second dimension (so >=0 and Float[torch.Tensor, "batch pos d_vocab_out"]: - return ( - einsum( - "batch pos d_model, d_model vocab -> batch pos vocab", - residual, - self.W_U, - ) - + self.b_U - ) - - -# Positional Embeddings -class PosEmbed(nn.Module): - def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): - super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg - self.W_pos = nn.Parameter( - torch.empty(self.cfg.n_ctx, self.cfg.d_model, dtype=cfg.dtype) - ) - - def forward( - self, - tokens: Int[torch.Tensor, "batch pos"], - past_kv_pos_offset: int = 0, - attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, - ) -> Float[torch.Tensor, "batch pos d_model"]: - """ - Forward pass for positional embeddings. - - Args: - tokens (Int[torch.Tensor, "batch pos"]): Input tokens. - past_kv_pos_offset (int, optional): The length of tokens in the past_kv_cache. Defaults to 0. - attention_mask (Int[torch.Tensor, "batch pos"], optional): The attention mask for padded tokens. - Defaults to None. - - Returns: - Float[torch.Tensor, "batch pos d_model"]: Absolute position embeddings. 
- """ - tokens_length = tokens.size(-1) - - if attention_mask is None: - pos_embed = self.W_pos[ - past_kv_pos_offset : tokens_length + past_kv_pos_offset, : - ] # [pos, d_model] - batch_pos_embed = einops.repeat( - pos_embed, "pos d_model -> batch pos d_model", batch=tokens.size(0) - ) - - else: - # Separated from the no padding case for computational efficiency - # (this code is a bit slower than the code above) - - offset_position_ids = get_offset_position_ids( - past_kv_pos_offset, attention_mask - ) - pos_embed = self.W_pos[offset_position_ids] # [batch, pos, d_model] - - # Set the position embeddings to 0 for pad tokens (this is an arbitrary choice) - padding_mask = ~attention_mask.bool() # [batch, tokens_length] - offset_padding_mask = padding_mask[ - :, past_kv_pos_offset : tokens_length + past_kv_pos_offset - ].unsqueeze( - -1 - ) # [batch, pos, 1] - batch_pos_embed = torch.where(offset_padding_mask, 0, pos_embed) - - return batch_pos_embed.clone() - - -class TokenTypeEmbed(nn.Module): - """ - The token-type embed is a binary ids indicating whether a token belongs to sequence A or B. For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, shape is (batch_size, sequence_length). - - See the BERT paper for more information: https://arxiv.org/pdf/1810.04805.pdf - """ - - def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): - super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg - self.W_token_type = nn.Parameter( - torch.empty(2, self.cfg.d_model, dtype=cfg.dtype) - ) - - def forward(self, token_type_ids: Int[torch.Tensor, "batch pos"]): - return self.W_token_type[token_type_ids, :] - - -class BertEmbed(nn.Module): - """ - Custom embedding layer for a BERT-like model. 
This module computes the sum of the token, positional and token-type embeddings and takes the layer norm of the result. - """ - - def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): - super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg - self.embed = Embed(cfg) - self.pos_embed = PosEmbed(cfg) - self.token_type_embed = TokenTypeEmbed(cfg) - self.ln = LayerNorm(cfg) - - self.hook_embed = HookPoint() - self.hook_pos_embed = HookPoint() - self.hook_token_type_embed = HookPoint() - - def forward( - self, - input_ids: Int[torch.Tensor, "batch pos"], - token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, - ): - base_index_id = torch.arange(input_ids.shape[1], device=input_ids.device) - index_ids = einops.repeat( - base_index_id, "pos -> batch pos", batch=input_ids.shape[0] - ) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - word_embeddings_out = self.hook_embed(self.embed(input_ids)) - position_embeddings_out = self.hook_pos_embed(self.pos_embed(index_ids)) - token_type_embeddings_out = self.hook_token_type_embed( - self.token_type_embed(token_type_ids) - ) - - embeddings_out = ( - word_embeddings_out + position_embeddings_out + token_type_embeddings_out - ) - layer_norm_out = self.ln(embeddings_out) - return layer_norm_out - - -class BertMLMHead(nn.Module): - """ - Transforms BERT embeddings into logits. The purpose of this module is to predict masked tokens in a sentence. 
- """ - - def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): - super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg - self.W = nn.Parameter(torch.empty(cfg.d_model, cfg.d_model, dtype=cfg.dtype)) - self.b = nn.Parameter(torch.zeros(cfg.d_model, dtype=cfg.dtype)) - self.act_fn = nn.GELU() - self.ln = LayerNorm(cfg) - - def forward(self, resid: Float[torch.Tensor, "batch pos d_model"]) -> torch.Tensor: - resid = ( - einsum( - "batch pos d_model_in, d_model_out d_model_in -> batch pos d_model_out", - resid, - self.W, - ) - + self.b - ) - resid = self.act_fn(resid) - resid = self.ln(resid) - return resid - - -# LayerNormPre -# I fold the LayerNorm weights and biases into later weights and biases. -# This is just the 'center and normalise' part of LayerNorm -# Centering is equivalent to just deleting one direction of residual space, -# and is equivalent to centering the weight matrices of everything writing to the residual stream -# Normalising is a funkier non-linear operation, that projects the residual stream onto the unit hypersphere -class LayerNormPre(nn.Module): - def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): - """LayerNormPre - the 'center and normalise' part of LayerNorm. Length is - normally d_model, but is d_mlp for softmax. Not needed as a parameter. 
This - should only be used in inference mode after folding in LayerNorm weights""" - super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg - self.eps = self.cfg.eps - - # Adds a hook point for the normalisation scale factor - self.hook_scale = HookPoint() # [batch, pos] - # Hook Normalized captures LN output - here it's a vector with std 1 and mean 0 - self.hook_normalized = HookPoint() # [batch, pos, length] - - def forward( - self, - x: Union[ - Float[torch.Tensor, "batch pos d_model"], - Float[torch.Tensor, "batch pos head_index d_model"], - ], - ) -> Union[ - Float[torch.Tensor, "batch pos d_model"], - Float[torch.Tensor, "batch pos head_index d_model"], - ]: - if self.cfg.dtype not in [torch.float32, torch.float64]: - x = x.to(torch.float32) - - x = x - x.mean(axis=-1, keepdim=True) # [batch, pos, length] - scale: Union[ - Float[torch.Tensor, "batch pos 1"], - Float[torch.Tensor, "batch pos head_index 1"], - ] = self.hook_scale((x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()) - return self.hook_normalized(x / scale).to(self.cfg.dtype) - - -class LayerNorm(nn.Module): - def __init__( - self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None - ): - """ - LayerNorm with optional length parameter - - length (Optional[int]): If the dimension of the LayerNorm. 
If not provided, assumed to be d_model - """ - super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg - self.eps = self.cfg.eps - if length is None: - self.length = self.cfg.d_model - else: - self.length = length - - self.w = nn.Parameter(torch.ones(self.length, dtype=cfg.dtype)) - self.b = nn.Parameter(torch.zeros(self.length, dtype=cfg.dtype)) - - # Adds a hook point for the normalisation scale factor - self.hook_scale = HookPoint() # [batch, pos, 1] - # Hook_normalized is on the LN output - self.hook_normalized = HookPoint() # [batch, pos, length] - - def forward( - self, - x: Union[ - Float[torch.Tensor, "batch pos d_model"], - Float[torch.Tensor, "batch pos head_index d_model"], - ], - ) -> Union[ - Float[torch.Tensor, "batch pos d_model"], - Float[torch.Tensor, "batch pos head_index d_model"], - ]: - if self.cfg.dtype not in [torch.float32, torch.float64]: - x = x.to(torch.float32) - - x = x - x.mean(axis=-1, keepdim=True) # [batch, pos, length] - scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale( - (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt() - ) - x = x / scale # [batch, pos, length] - return self.hook_normalized(x * self.w + self.b).to(self.cfg.dtype) - - -class RMSNormPre(nn.Module): - def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): - """RMSNormPre - LayerNormPre without the centering and bias (RMS = Root Mean Square)""" - super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg - self.eps = self.cfg.eps - - # Adds a hook point for the normalisation scale factor - self.hook_scale = HookPoint() # [batch, pos] - self.hook_normalized = HookPoint() # [batch, pos, length] - - def forward( - self, x: Float[torch.Tensor, "batch pos length"] - ) -> Float[torch.Tensor, "batch pos length"]: - if self.cfg.dtype not in [torch.float32, torch.float64]: - x = x.to(torch.float32) - - scale: Float[torch.Tensor, "batch pos 1"] = 
self.hook_scale( - (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt() - ) - return self.hook_normalized(x / scale).to( - self.cfg.dtype - ) # [batch, pos, length] - - -class RMSNorm(nn.Module): - def __init__( - self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None - ): - """ - RMSNorm - LayerNorm without the centering and bias (RMS = Root Mean Square) - - length (Optional[int]): If the dimension of the RMSNorm. If not provided, assumed to be d_model - """ - super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg - self.eps = self.cfg.eps - if length is None: - self.length = self.cfg.d_model - else: - self.length = length - - self.w = nn.Parameter(torch.ones(self.length, dtype=cfg.dtype)) - - # Adds a hook point for the normalisation scale factor - self.hook_scale = HookPoint() # [batch, pos, 1] - self.hook_normalized = HookPoint() # [batch, pos, length] - - def forward( - self, x: Float[torch.Tensor, "batch pos length"] - ) -> Float[torch.Tensor, "batch pos length"]: - if self.cfg.dtype not in [torch.float32, torch.float64]: - x = x.to(torch.float32) - - scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale( - (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt() - ) - x = self.hook_normalized(x / scale).to(self.cfg.dtype) # [batch, pos, length] - return x * self.w - - -# Attention -class Attention(nn.Module): - def __init__( - self, - cfg: Union[Dict, HookedTransformerConfig], - attn_type: str = "global", - layer_id: Optional[int] = None, - ): - """Attention Block - params have shape [head_index, d_model, d_head] (or [head_index, d_head, d_model] for W_O) and multiply on the right. 
attn_scores refers to query key dot product immediately before attention softmax - - Convention: All attention pattern-style matrices have shape [batch, head_index, query_pos, key_pos] - - Args: - cfg (Union[Dict, HookedTransformerConfig]): Config - attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global". - layer_id (int, optional): The index of the current layer. Used by the Mistal models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. - """ - super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg - self.W_Q = nn.Parameter( - torch.empty( - self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype - ) - ) - self.W_K = nn.Parameter( - torch.empty( - self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype - ) - ) - self.W_V = nn.Parameter( - torch.empty( - self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype - ) - ) - self.W_O = nn.Parameter( - torch.empty( - self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model, dtype=cfg.dtype - ) - ) - self.b_Q = nn.Parameter( - torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype) - ) - self.b_K = nn.Parameter( - torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype) - ) - self.b_V = nn.Parameter( - torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype) - ) - self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype)) - - self.attn_type = attn_type - # Create a max_ctx x max_ctx mask, with True iff that query position - # can attend to that key position (query is first axis, key is second axis) - causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool()) - if self.attn_type == "global": - # For global attention, this is a 
lower triangular matrix - key <= query - self.register_buffer("mask", causal_mask) - elif self.attn_type == "local": - # For local, this is banded, query - window_size < key <= query - assert isinstance(self.cfg.window_size, int) - self.register_buffer( - "mask", torch.triu(causal_mask, 1 - self.cfg.window_size) - ) - else: - raise ValueError(f"Invalid attention type: {self.attn_type}") - - self.register_buffer("IGNORE", torch.tensor(-torch.inf)) - - self.layer_id = layer_id - - # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability? - if self.cfg.use_attn_scale: - self.attn_scale = np.sqrt(self.cfg.d_head) - else: - self.attn_scale = 1.0 - if self.cfg.scale_attn_by_inverse_layer_idx: - self.attn_scale *= self.layer_id + 1 - - self.hook_k = HookPoint() # [batch, pos, head_index, d_head] - self.hook_q = HookPoint() # [batch, pos, head_index, d_head] - self.hook_v = HookPoint() # [batch, pos, head_index, d_head] - self.hook_z = HookPoint() # [batch, pos, head_index, d_head] - self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] - self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos] - self.hook_result = HookPoint() # [batch, pos, head_index, d_model] - - # See HookedTransformerConfig for more details. - if self.cfg.positional_embedding_type == "shortformer": - # This tracks the input to the keys and queries, which is resid_pre + pos_embeds - self.hook_attn_input = HookPoint() # [batch, pos, d_model] - elif self.cfg.positional_embedding_type == "rotary": - # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. 
See HookedTransformerConfig for details - self.hook_rot_k = HookPoint() - self.hook_rot_q = HookPoint() - sin, cos = self.calculate_sin_cos_rotary( - self.cfg.rotary_dim, self.cfg.n_ctx, dtype=self.cfg.dtype - ) - self.register_buffer("rotary_sin", sin) - self.register_buffer("rotary_cos", cos) - elif self.cfg.positional_embedding_type == "alibi": - # ALiBi bias wil be constructed on the first forward pass. - # Note: While computationally efficient, initializing an bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage. - self.alibi = None - - @property - def OV(self) -> FactoredMatrix: - """ - OV-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more) - - Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry! - - Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works. - """ - return FactoredMatrix(self.W_V, self.W_O) - - @property - def QK(self) -> FactoredMatrix: - """ - QK-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity in the key-query dot product, the output is purely determined by the matrix W_QK = W_Q.T @ W_K, and not W_Q or W_K individually. (Mathematically, for a single head, pattern = destination_residual.T @ W_Q.T @ W_K @ source-residual, see the glossary for more). 
- - Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos] - - Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works. - """ - W_K_transpose = einops.rearrange( - self.W_K, "head_index d_model d_head -> head_index d_head d_model" - ) - return FactoredMatrix(self.W_Q, W_K_transpose) - - def forward( - self, - query_input: Union[ - Float[torch.Tensor, "batch pos d_model"], - Float[torch.Tensor, "batch pos head_index d_model"], - ], - key_input: Union[ - Float[torch.Tensor, "batch pos d_model"], - Float[torch.Tensor, "batch pos head_index d_model"], - ], - value_input: Union[ - Float[torch.Tensor, "batch pos d_model"], - Float[torch.Tensor, "batch pos head_index d_model"], - ], - past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, - additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None, - attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, - ) -> Float[torch.Tensor, "batch pos d_model"]: - """ - shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details - past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None - additive_attention_mask is an optional mask to add to the attention weights. Defaults to None. - attention_mask is the attention mask for padded tokens. Defaults to None. 
- """ - - if self.cfg.use_split_qkv_input or self.cfg.use_attn_in: - qkv_einops_string = "batch pos head_index d_model" - else: - qkv_einops_string = "batch pos d_model" - q = self.hook_q( - einsum( - f"{qkv_einops_string}, head_index d_model d_head \ - -> batch pos head_index d_head", - query_input, - self.W_Q, - ) - + self.b_Q - ) # [batch, pos, head_index, d_head] - k = self.hook_k( - einsum( - f"{qkv_einops_string}, head_index d_model d_head \ - -> batch pos head_index d_head", - key_input, - self.W_K, - ) - + self.b_K - ) # [batch, pos, head_index, d_head] - v = self.hook_v( - einsum( - f"{qkv_einops_string}, head_index d_model d_head \ - -> batch pos head_index d_head", - value_input, - self.W_V, - ) - + self.b_V - ) # [batch, pos, head_index, d_head] - - if past_kv_cache_entry is not None: - # Appends the new keys and values to the cached values, and automatically updates the cache - kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1) - k, v = past_kv_cache_entry.append(k, v) - else: - # Not using a cache - kv_cache_pos_offset = 0 - - if self.cfg.positional_embedding_type == "rotary": - q = self.hook_rot_q( - self.apply_rotary(q, kv_cache_pos_offset, attention_mask) - ) - k = self.hook_rot_k( - self.apply_rotary(k, 0, attention_mask) - ) # keys are cached so no offset - - if self.cfg.dtype not in [torch.float32, torch.float64]: - # If using 16 bits, increase the precision to avoid numerical instabilities - q = q.to(torch.float32) - k = k.to(torch.float32) - - attn_scores = ( - einsum( - "batch query_pos head_index d_head, \ - batch key_pos head_index d_head \ - -> batch head_index query_pos key_pos", - q, - k, - ) - / self.attn_scale - ) # [batch, head_index, query_pos, key_pos] - - if self.cfg.positional_embedding_type == "alibi": - query_ctx = attn_scores.size(-2) - # The key context length is the number of positions in the past - this includes all positions in the cache - key_ctx = attn_scores.size(-1) - - # only recompute when necessary to 
increase efficiency. - if self.alibi is None or key_ctx > self.alibi.size(-1): - self.alibi = Attention.create_alibi_bias( - self.cfg.n_heads, key_ctx, self.cfg.device - ) - - attn_scores += self.alibi[ - :, :query_ctx, :key_ctx - ] # [batch, head_index, query_pos, key_pos] - - if self.cfg.attention_dir == "causal": - # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. - attn_scores = self.apply_causal_mask( - attn_scores, kv_cache_pos_offset, attention_mask - ) # [batch, head_index, query_pos, key_pos] - if additive_attention_mask is not None: - attn_scores += additive_attention_mask - - attn_scores = self.hook_attn_scores(attn_scores) - pattern = F.softmax(attn_scores, dim=-1) - pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) - pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] - pattern = pattern.to(self.cfg.dtype) - z = self.hook_z( - einsum( - "batch key_pos head_index d_head, \ - batch head_index query_pos key_pos -> \ - batch query_pos head_index d_head", - v, - pattern, - ) - ) # [batch, pos, head_index, d_head] - if not self.cfg.use_attn_result: - out = ( - ( - einsum( - "batch pos head_index d_head, \ - head_index d_head d_model -> \ - batch pos d_model", - z, - self.W_O, - ) - ) - + self.b_O - ) # [batch, pos, d_model] - else: - # Explicitly calculate the attention result so it can be accessed by a hook - # This is off by default because it can easily eat through your GPU memory. 
- result = self.hook_result( - einsum( - "batch pos head_index d_head, \ - head_index d_head d_model -> \ - batch pos head_index d_model", - z, - self.W_O, - ) - ) # [batch, pos, head_index, d_model] - out = ( - einops.reduce( - result, "batch position index model->batch position model", "sum" - ) - + self.b_O - ) # [batch, pos, d_model] - return out - - def apply_causal_mask( - self, - attn_scores: Float[ - torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset" - ], - past_kv_pos_offset: int = 0, - attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, - ): - # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different. - query_ctx_length = attn_scores.size(-2) - # The key context length is the number of positions in the past - this includes all positions in the cache - # If not caching, query_ctx_length == key_ctx_length - key_ctx_length = attn_scores.size(-1) - - assert ( - query_ctx_length + past_kv_pos_offset == key_ctx_length - ), f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug." 
- - # Index back to front to ensure local attention works - final_mask = self.mask[ - None, None, -query_ctx_length:, -key_ctx_length: - ] # [1, 1, pos, pos] - if attention_mask is not None: - # Apply a causal mask to the attention scores considering the padding - einsum_str = "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos" - final_mask = einops.einsum(final_mask, attention_mask, einsum_str).bool() - - return torch.where(final_mask, attn_scores, self.IGNORE) - - def calculate_sin_cos_rotary( - self, - rotary_dim: int, - n_ctx: int, - base: int = 10000, - dtype: torch.dtype = torch.float32, - ) -> Tuple[ - Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"] - ]: - """ - Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details - - Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, in GPT-NeoX the pair of elements at k and k+n//2 are rotated (ie folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent. - To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is. 
- """ - high_precision = torch.float32 if dtype != torch.float64 else torch.float64 - pos = torch.arange(n_ctx, dtype=high_precision) - dim = torch.arange(rotary_dim // 2, dtype=high_precision) - - # A set of frequencies evenly spaced in log space - freq = base ** (dim / (rotary_dim / 2)) - if self.cfg.original_architecture in ["GPTNeoXForCausalLM", "LlamaForCausalLM"]: - freq = einops.repeat(freq, "d -> (2 d)") - else: - freq = einops.repeat(freq, "d -> (d 2)") - # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency - angles = pos[:, None] / freq[None, :] - return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) - - def rotate_every_two( - self, x: Float[torch.Tensor, "... rotary_dim"] - ) -> Float[torch.Tensor, "... rotary_dim"]: - """ - Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0] - - The final axis of x must have even length. - - GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details. 
- """ - rot_x = x.clone() - if self.cfg.original_architecture in ["GPTNeoXForCausalLM", "LlamaForCausalLM"]: - n = x.size(-1) // 2 - rot_x[..., :n] = -x[..., n:] - rot_x[..., n:] = x[..., :n] - else: - rot_x[..., ::2] = -x[..., 1::2] - rot_x[..., 1::2] = x[..., ::2] - - return rot_x - - def apply_rotary( - self, - x: Float[torch.Tensor, "batch pos head_index d_head"], - past_kv_pos_offset=0, - attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, - ) -> Float[torch.Tensor, "batch pos head_index d_head"]: - # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions) - x_pos = x.size(1) - x_rot = x[..., : self.cfg.rotary_dim] - x_pass = x[..., self.cfg.rotary_dim :] - x_flip = self.rotate_every_two(x_rot) - - if attention_mask is None: - rotary_cos = self.rotary_cos[ - None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : - ] - rotary_sin = self.rotary_sin[ - None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : - ] - x_rotated = x_rot * rotary_cos + x_flip * rotary_sin - else: - offset_position_ids = get_offset_position_ids( - past_kv_pos_offset, attention_mask - ) - mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :] - mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :] - x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin - - return torch.cat([x_rotated, x_pass], dim=-1) - - @staticmethod - def create_alibi_slope( - n_ctx: int, device: torch.device = None - ) -> Float[torch.Tensor, "query key"]: - """Create an ALiBi Slope Matrix. - - Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar. - - See :meth:`create_alibi_bias` for the full ALiBi bias calculation. 
- - Examples: - - >>> Attention.create_alibi_slope(3) - tensor([[ 0., 0., 0.], - [-1., 0., 0.], - [-2., -1., 0.]]) - - >>> Attention.create_alibi_slope(4) - tensor([[ 0., 0., 0., 0.], - [-1., 0., 0., 0.], - [-2., -1., 0., 0.], - [-3., -2., -1., 0.]]) - - Args: - n_ctx: The maximum number of tokens in a prompt. - - Returns: - A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower - triangle is decreasing by a constant slope of 1 (towards the bottom left corner). - """ - # set rows as [[0,1,2...]] - rows = torch.arange(n_ctx, device=device).unsqueeze(0) - - # Set cols as [[0],[1],[2]...] - cols = torch.arange(n_ctx, device=device).unsqueeze(1) - - # Use broadcasting to create the desired lower triangular part of the matrix - slope_matrix = rows - cols - - # Use the clamp method to set all positive values (upper right triangle) to - return slope_matrix.clamp(max=0).to(torch.float32) - - @staticmethod - def create_alibi_multipliers( - n_heads: int, device: torch.device = None - ) -> Float[torch.Tensor, "head_idx"]: - """Create the ALiBi Scalar Multipliers for each Head. - - For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and - uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1), - 1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... , 1/(2^8)]. - - See :meth:`create_alibi_bias` for the full ALiBi bias calculation. - - Examples: - - >>> Attention.create_alibi_multipliers(8) - tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039]) - - >>> Attention.create_alibi_multipliers(16) - tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312, - 0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039]) - - Args: - n_heads: The number of heads in a layer. - device: The device to create the tensor on. - - Returns: - A tensor of shape (n_heads,) containing the scalar multiplier for each head. 
- """ - # Calculate the starting value - start = 2 ** (-8 / n_heads) - - # Generate the indices [0, 1, ..., n_heads-1] - indices = torch.arange(n_heads, device=device) - - # Compute the multipliers, with the starting value being the same as the ratio - multipliers = start * (start**indices) - - return multipliers - - @staticmethod - def create_alibi_bias( - n_heads: int, n_ctx: int, device: torch.device = None - ) -> Float[torch.Tensor, "head_idx query key"]: - """Create the ALiBi Bias for all Heads. - - Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer. - - The broad idea behind ALiBi is to remove the positional encoding from the original transformer - model, and instead apply a bias to each attention score. This bias is proportional to the - distance between the query and key (i.e. it encourage paying less attention to more distant - tokens), and is added to the attention scores before the softmax. It is used in models such as - Bloom. - - Examples: - - >>> Attention.create_alibi_bias(2, 4, torch.device('cpu')) - tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000], - [-0.0625, 0.0000, 0.0000, 0.0000], - [-0.1250, -0.0625, 0.0000, 0.0000], - [-0.1875, -0.1250, -0.0625, 0.0000]], - [[ 0.0000, 0.0000, 0.0000, 0.0000], - [-0.0039, 0.0000, 0.0000, 0.0000], - [-0.0078, -0.0039, 0.0000, 0.0000], - [-0.0117, -0.0078, -0.0039, 0.0000]]]) - - Args: - n_heads: The number of heads in a layer. - n_ctx: The maximum number of tokens in a prompt. - device: The device to create the tensor on. - - Returns: - The ALiBi bias that should be added to the attention scores before the softmax. - """ - # Create the slope matrix - slope: Float[torch.Tensor, "query key"] = Attention.create_alibi_slope( - n_ctx, device - ) - - # Create the scalar multiplier for each head. 
- multipliers: Float[ - torch.Tensor, "head_idx" - ] = Attention.create_alibi_multipliers(n_heads, device) - - # The ALiBi bias is then m * slope_matrix - alibi_bias = torch.einsum("ij,k->kij", slope, multipliers) - - return alibi_bias - - -# MLP Layers -class MLP(nn.Module): - def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): - super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg - self.W_in = nn.Parameter( - torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype) - ) - self.b_in = nn.Parameter(torch.zeros(self.cfg.d_mlp, dtype=cfg.dtype)) - self.W_out = nn.Parameter( - torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=cfg.dtype) - ) - self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype)) - - self.hook_pre = HookPoint() # [batch, pos, d_mlp] - self.hook_post = HookPoint() # [batch, pos, d_mlp] - - if self.cfg.act_fn == "relu": - self.act_fn = F.relu - elif self.cfg.act_fn == "gelu": - self.act_fn = F.gelu - elif self.cfg.act_fn == "silu": - self.act_fn = F.silu - elif self.cfg.act_fn == "gelu_new": - self.act_fn = gelu_new - elif self.cfg.act_fn == "gelu_fast": - self.act_fn = gelu_fast - elif self.cfg.act_fn == "solu_ln": - self.act_fn = solu - # Hook taken between activation and layer norm - self.hook_mid = HookPoint() # [batch, pos, d_mlp] - if self.cfg.normalization_type == "LN": - self.ln = LayerNorm(self.cfg, self.cfg.d_mlp) - else: - self.ln = LayerNormPre(self.cfg) - - else: - raise ValueError(f"Invalid activation function name: {self.cfg.act_fn}") - - def forward( - self, x: Float[torch.Tensor, "batch pos d_model"] - ) -> Float[torch.Tensor, "batch pos d_model"]: - # Technically, all these einsums could be done with a single matmul, but this is more readable. 
- pre_act = self.hook_pre( - einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in) - + self.b_in - ) # [batch, pos, d_mlp] - if not self.cfg.act_fn.endswith("_ln"): - post_act = self.hook_post(self.act_fn(pre_act)) # [batch, pos, d_mlp] - else: - mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] - post_act = self.hook_post(self.ln(mid_act)) - return ( - einsum( - "batch pos d_mlp, d_mlp d_model -> batch pos d_model", - post_act, - self.W_out, - ) - + self.b_out - ) - - -# TODO -# not sure whether to fold this into MLP or not -class GatedMLP(nn.Module): - """ - The equation of a gated MLP: - pre = x @ W_gate - pre_linear = x @ W_in - post = Gelu(pre) * (pre_linear) + b_in - mlp_out = post @ W_out + b_out - - In one equation, mlp_out = (Gelu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out - """ - - def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): - super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg - self.W_in = nn.Parameter( - torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype) - ) - self.W_gate = nn.Parameter( - torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype) - ) - self.b_in = nn.Parameter(torch.zeros(self.cfg.d_mlp, dtype=cfg.dtype)) - self.W_out = nn.Parameter( - torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=cfg.dtype) - ) - self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype)) - - # hook on gate output but before act_fn - self.hook_pre = HookPoint() # [batch, pos, d_mlp] - # hook on the linear component of the input - self.hook_pre_linear = HookPoint() # [batch, pos, d_mlp] - # hook on act_fn(gate_output) * W_in(x) + b_in - self.hook_post = HookPoint() # [batch, pos, d_mlp] - - if self.cfg.act_fn == "relu": - self.act_fn = F.relu - elif self.cfg.act_fn == "gelu": - self.act_fn = F.gelu - elif self.cfg.act_fn == "silu": - self.act_fn = F.silu - elif self.cfg.act_fn == "gelu_new": - self.act_fn = 
gelu_new - elif self.cfg.act_fn == "gelu_fast": - self.act_fn = gelu_fast - elif self.cfg.act_fn == "solu_ln": - self.act_fn = solu - # Hook taken between activation and layer norm - self.hook_mid = HookPoint() # [batch, pos, d_mlp] - if self.cfg.normalization_type == "LN": - self.ln = LayerNorm(self.cfg, self.cfg.d_mlp) - else: - self.ln = LayerNormPre(self.cfg) - - else: - raise ValueError(f"Invalid activation function name: {self.cfg.act_fn}") - - def forward( - self, x: Float[torch.Tensor, "batch pos d_model"] - ) -> Float[torch.Tensor, "batch pos d_model"]: - # Technically, all these einsums could be done with a single matmul, but this is more readable. - pre_act = self.hook_pre( - einsum( - "batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_gate - ) - ) # [batch, pos, d_mlp] - if not self.cfg.act_fn.endswith("_ln"): - pre_linear = self.hook_pre_linear( - einsum( - "batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in - ) - ) - post_act = self.hook_post( - (self.act_fn(pre_act) * pre_linear) + self.b_in - ) # [batch, pos, d_mlp] - else: - mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] - post_act = self.hook_post(self.ln(mid_act)) - return ( - einsum( - "batch pos d_mlp, d_mlp d_model -> batch pos d_model", - post_act, - self.W_out, - ) - + self.b_out - ) - - -# Transformer Block -class TransformerBlock(nn.Module): - def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index): - super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg - if self.cfg.normalization_type == "LN": - self.ln1 = LayerNorm(cfg) - if not self.cfg.attn_only: - self.ln2 = LayerNorm(cfg) - elif self.cfg.normalization_type == "LNPre": - # We've folded in LayerNorm weights, so just need the center + scale parts - self.ln1 = LayerNormPre(cfg) - if not self.cfg.attn_only: - self.ln2 = LayerNormPre(cfg) - elif self.cfg.normalization_type == "RMS": - self.ln1 = RMSNorm(cfg) - if 
not self.cfg.attn_only: - self.ln2 = RMSNorm(cfg) - elif self.cfg.normalization_type == "RMSPre": - self.ln1 = RMSNormPre(cfg) - if not self.cfg.attn_only: - self.ln2 = RMSNormPre(cfg) - elif self.cfg.normalization_type is None: - self.ln1 = nn.Identity() - if not self.cfg.attn_only: - self.ln2 = nn.Identity() - else: - logging.warning( - f"Invalid normalization_type passed in {self.cfg.normalization_type}" - ) - - if not self.cfg.use_local_attn: - self.attn = Attention(cfg, "global", block_index) - else: - assert self.cfg.attn_types is not None - attn_type = self.cfg.attn_types[block_index] - self.attn = Attention(cfg, attn_type, block_index) - if not self.cfg.attn_only: - if self.cfg.gated_mlp: - self.mlp = GatedMLP(cfg) - else: - self.mlp = MLP(cfg) - - self.hook_attn_in = HookPoint() # [batch, pos, n_heads, d_model] - self.hook_q_input = HookPoint() # [batch, pos, n_heads, d_model] - self.hook_k_input = HookPoint() # [batch, pos, n_heads, d_model] - self.hook_v_input = HookPoint() # [batch, pos, n_heads, d_model] - self.hook_mlp_in = HookPoint() # [batch, pos, d_model] - - self.hook_attn_out = HookPoint() # [batch, pos, d_model] - self.hook_mlp_out = HookPoint() # [batch, pos, d_model] - - self.hook_resid_pre = HookPoint() # [batch, pos, d_model] - if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp: - self.hook_resid_mid = HookPoint() # [batch, pos, d_model] - self.hook_resid_post = HookPoint() # [batch, pos, d_model] - - def forward( - self, - resid_pre: Float[torch.Tensor, "batch pos d_model"], - shortformer_pos_embed: Optional[ - Float[torch.Tensor, "batch pos d_model"] - ] = None, - past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, - attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, - ) -> Float[torch.Tensor, "batch pos d_model"]: - """A single Transformer block. 
- - Args: - resid_pre (torch.Tensor): The residual stream - shape [batch, pos, d_model] - cache (HookedTransformerKeyValueCache): A cache of previous keys and values, used only when generating text. Defaults to None. - shortformer_pos_embed (torch.Tensor, optional): Only used for positional_embeddings_type == "shortformer". The positional embeddings. See HookedTransformerConfig for details. Defaults to None. - attention_mask (torch.Tensor, optional): The attention mask for padded tokens. Defaults to None. - - Returns: - _type_: _description_ - """ - resid_pre = self.hook_resid_pre(resid_pre) # [batch, pos, d_model] - - def add_head_dimension( - tensor: Float[torch.Tensor, "batch pos d_model"], - clone_tensor=True, - # `einops.repeat` uses a view in torch, so we generally clone the tensor to avoid using shared storage for each head entry - ): - repeated_tensor = einops.repeat( - tensor, - "batch pos d_model -> batch pos n_heads d_model", - n_heads=self.cfg.n_heads, - ) - if clone_tensor: - return repeated_tensor.clone() - else: - return repeated_tensor - - if self.cfg.use_attn_in or self.cfg.use_split_qkv_input: - # We're adding a head dimension - attn_in = add_head_dimension(resid_pre, clone_tensor=False) - if shortformer_pos_embed is not None: - shortformer_pos_embed = add_head_dimension(shortformer_pos_embed) - else: - attn_in = resid_pre - - if self.cfg.use_attn_in: - attn_in = self.hook_attn_in(attn_in.clone()) - - if self.cfg.use_split_qkv_input: - query_input = self.hook_q_input(attn_in.clone()) - key_input = self.hook_k_input(attn_in.clone()) - value_input = self.hook_v_input(attn_in.clone()) - else: - query_input = attn_in - key_input = attn_in - value_input = attn_in - - attn_out = self.hook_attn_out( - # hook the residual stream states that are used to calculate the - # queries, keys and values, independently. - # Then take the layer norm of these inputs, and pass these to the attention module. 
- self.attn( - query_input=self.ln1(query_input) - + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed), - key_input=self.ln1(key_input) - + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed), - value_input=self.ln1(value_input), - past_kv_cache_entry=past_kv_cache_entry, - attention_mask=attention_mask, - ) - ) # [batch, pos, d_model] - if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp: - resid_mid = self.hook_resid_mid( - resid_pre + attn_out - ) # [batch, pos, d_model] - mlp_in = ( - resid_mid - if not self.cfg.use_hook_mlp_in - else self.hook_mlp_in(resid_mid.clone()) - ) - normalized_resid_mid = self.ln2(mlp_in) - mlp_out = self.hook_mlp_out( - self.mlp(normalized_resid_mid) - ) # [batch, pos, d_model] - resid_post = self.hook_resid_post( - resid_mid + mlp_out - ) # [batch, pos, d_model] - elif self.cfg.parallel_attn_mlp: - # Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used. - # In GPT-J, LN1 and LN2 are tied, in GPT-NeoX they aren't. - normalized_resid_pre_2 = self.ln2( - resid_pre - if not self.cfg.use_hook_mlp_in - else self.hook_mlp_in(resid_pre.clone()) - ) - mlp_out = self.hook_mlp_out( - self.mlp(normalized_resid_pre_2) - ) # [batch, pos, d_model] - resid_post = self.hook_resid_post( - resid_pre + attn_out + mlp_out - ) # [batch, pos, d_model] - else: - resid_post = self.hook_resid_post( - resid_pre + attn_out - ) # [batch, pos, d_model] - return resid_post - - -class BertBlock(nn.Module): - """ - BERT Block. Similar to the TransformerBlock, except that the LayerNorms are applied after the attention and MLP, rather than before. 
- """ - - def __init__(self, cfg: HookedTransformerConfig): - super().__init__() - self.cfg = cfg - - self.attn = Attention(cfg) - self.ln1 = LayerNorm(cfg) - self.mlp = MLP(cfg) - self.ln2 = LayerNorm(cfg) - - self.hook_q_input = HookPoint() # [batch, pos, n_heads, d_model] - self.hook_k_input = HookPoint() # [batch, pos, n_heads, d_model] - self.hook_v_input = HookPoint() # [batch, pos, n_heads, d_model] - - self.hook_attn_out = HookPoint() # [batch, pos, d_model] - self.hook_mlp_in = HookPoint() # [batch, pos, d_model] - self.hook_mlp_out = HookPoint() # [batch, pos, d_model] - self.hook_resid_pre = HookPoint() # [batch, pos, d_model] - self.hook_resid_mid = HookPoint() # [batch, pos, d_model] - self.hook_resid_post = HookPoint() # [batch, pos, d_model] - self.hook_normalized_resid_post = HookPoint() # [batch, pos, d_model] - - def forward( - self, - resid_pre: Float[torch.Tensor, "batch pos d_model"], - additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None, - ): - resid_pre = self.hook_resid_pre(resid_pre) - - query_input = resid_pre - key_input = resid_pre - value_input = resid_pre - - if self.cfg.use_split_qkv_input: - - def add_head_dimension(tensor): - return einops.repeat( - tensor, - "batch pos d_model -> batch pos n_heads d_model", - n_heads=self.cfg.n_heads, - ).clone() - - query_input = self.hook_q_input(add_head_dimension(query_input)) - key_input = self.hook_k_input(add_head_dimension(key_input)) - value_input = self.hook_v_input(add_head_dimension(value_input)) - - attn_out = self.hook_attn_out( - self.attn( - query_input, - key_input, - value_input, - additive_attention_mask=additive_attention_mask, - ) - ) - resid_mid = self.hook_resid_mid(resid_pre + attn_out) - - mlp_in = ( - resid_mid - if not self.cfg.use_hook_mlp_in - else self.hook_mlp_in(resid_mid.clone()) - ) - normalized_resid_mid = self.ln1(mlp_in) - mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid)) - resid_post = 
self.hook_resid_post(normalized_resid_mid + mlp_out) - normalized_resid_post = self.hook_normalized_resid_post(self.ln2(resid_post)) - - return normalized_resid_post diff --git a/transformer_lens/components/BertEmbed.py b/transformer_lens/components/BertEmbed.py new file mode 100644 index 000000000..dfdde0bd2 --- /dev/null +++ b/transformer_lens/components/BertEmbed.py @@ -0,0 +1,62 @@ +"""Hooked Transformer Components. + +This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) +needed to create many different types of generative language models. They are used by +:class:`transformer_lens.HookedTransformer`. +""" +from .embed import Embed +from .layer_norm import LayerNorm +from .pos_embed import PosEmbed +from .token_typed_embed import TokenTypeEmbed +import einops +from jaxtyping import Int +import torch +import torch.nn as nn +from transformer_lens.hook_points import HookPoint +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from typing import Dict, Optional, Union + + +class BertEmbed(nn.Module): + """ + Custom embedding layer for a BERT-like model. This module computes the sum of the token, positional and token-type embeddings and takes the layer norm of the result. 
+ """ + + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.embed = Embed(cfg) + self.pos_embed = PosEmbed(cfg) + self.token_type_embed = TokenTypeEmbed(cfg) + self.ln = LayerNorm(cfg) + + self.hook_embed = HookPoint() + self.hook_pos_embed = HookPoint() + self.hook_token_type_embed = HookPoint() + + def forward( + self, + input_ids: Int[torch.Tensor, "batch pos"], + token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, + ): + base_index_id = torch.arange(input_ids.shape[1], device=input_ids.device) + index_ids = einops.repeat( + base_index_id, "pos -> batch pos", batch=input_ids.shape[0] + ) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + word_embeddings_out = self.hook_embed(self.embed(input_ids)) + position_embeddings_out = self.hook_pos_embed(self.pos_embed(index_ids)) + token_type_embeddings_out = self.hook_token_type_embed( + self.token_type_embed(token_type_ids) + ) + + embeddings_out = ( + word_embeddings_out + position_embeddings_out + token_type_embeddings_out + ) + layer_norm_out = self.ln(embeddings_out) + return layer_norm_out + diff --git a/transformer_lens/components/__init__.py b/transformer_lens/components/__init__.py new file mode 100644 index 000000000..14bb14085 --- /dev/null +++ b/transformer_lens/components/__init__.py @@ -0,0 +1,15 @@ +from .attention import Attention +from .bert_block import BertBlock +from .bert_embed import BertEmbed +from .bert_mlm_head import BertMLMHead +from .embed import Embed +from .gated_mlp import GatedMLP +from .layer_norm import LayerNorm +from .layer_norm_pre import LayerNormPre +from .mlp import MLP +from .pos_embed import PosEmbed +from .rms_norm import RMSNorm +from .rms_norm_pre import RMSNormPre +from .token_typed_embed import TokenTypeEmbed +from .transformer_block import TransformerBlock +from .unembed import Unembed \ No newline at 
end of file diff --git a/transformer_lens/components/attention.py b/transformer_lens/components/attention.py new file mode 100644 index 000000000..04c8eff4c --- /dev/null +++ b/transformer_lens/components/attention.py @@ -0,0 +1,566 @@ +"""Hooked Transformer Components. + +This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) +needed to create many different types of generative language models. They are used by +:class:`transformer_lens.HookedTransformer`. +""" + +import einops +from fancy_einsum import einsum +from jaxtyping import Float, Int +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformer_lens.FactoredMatrix import FactoredMatrix +from transformer_lens.hook_points import HookPoint +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry +from transformer_lens.utils import get_offset_position_ids +from typing import Dict, Optional, Tuple, Union + + +# Attention +class Attention(nn.Module): + def __init__( + self, + cfg: Union[Dict, HookedTransformerConfig], + attn_type: str = "global", + layer_id: Optional[int] = None, + ): + """Attention Block - params have shape [head_index, d_model, d_head] (or [head_index, d_head, d_model] for W_O) and multiply on the right. attn_scores refers to query key dot product immediately before attention softmax + + Convention: All attention pattern-style matrices have shape [batch, head_index, query_pos, key_pos] + + Args: + cfg (Union[Dict, HookedTransformerConfig]): Config + attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global". + layer_id (int, optional): The index of the current layer. 
Used by the Mistal models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. + """ + super().__init__() + + self.cfg = HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg + + self.W_Q = nn.Parameter( + torch.empty( + self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype + ) + ) + self.W_K = nn.Parameter( + torch.empty( + self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype + ) + ) + self.W_V = nn.Parameter( + torch.empty( + self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype + ) + ) + self.W_O = nn.Parameter( + torch.empty( + self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model, dtype=cfg.dtype + ) + ) + self.b_Q = nn.Parameter( + torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype) + ) + self.b_K = nn.Parameter( + torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype) + ) + self.b_V = nn.Parameter( + torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype) + ) + self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype)) + + # Create a max_ctx x max_ctx mask, with True iff that query position + # can attend to that key position (query is first axis, key is second axis) + causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool()) + + if attn_type == "global": + # For global attention, this is a lower triangular matrix - key <= query + self.register_buffer("mask", causal_mask) + elif attn_type == "local": + # For local, this is banded, query - window_size < key <= query + assert isinstance(self.cfg.window_size, int) + self.register_buffer( + "mask", torch.triu(causal_mask, 1 - self.cfg.window_size) + ) + else: + raise ValueError(f"Invalid attention type: {attn_type}") + + self.register_buffer("IGNORE", torch.tensor(-torch.inf)) + + self.layer_id = layer_id + + # attn_scale is a constant that we divide the attention scores by pre-softmax. 
I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability? + self.attn_scale = np.sqrt(self.cfg.d_head) if self.cfg.use_attn_scale else 1.0 + + if self.cfg.scale_attn_by_inverse_layer_idx: + self.attn_scale *= self.layer_id + 1 + + self.hook_k = HookPoint() # [batch, pos, head_index, d_head] + self.hook_q = HookPoint() # [batch, pos, head_index, d_head] + self.hook_v = HookPoint() # [batch, pos, head_index, d_head] + self.hook_z = HookPoint() # [batch, pos, head_index, d_head] + self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] + self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos] + self.hook_result = HookPoint() # [batch, pos, head_index, d_model] + + # See HookedTransformerConfig for more details. + if self.cfg.positional_embedding_type == "shortformer": + # This tracks the input to the keys and queries, which is resid_pre + pos_embeds + self.hook_attn_input = HookPoint() # [batch, pos, d_model] + elif self.cfg.positional_embedding_type == "rotary": + # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. See HookedTransformerConfig for details + self.hook_rot_k = HookPoint() + self.hook_rot_q = HookPoint() + sin, cos = self.calculate_sin_cos_rotary( + self.cfg.rotary_dim, self.cfg.n_ctx, dtype=self.cfg.dtype + ) + self.register_buffer("rotary_sin", sin) + self.register_buffer("rotary_cos", cos) + + @property + def OV(self) -> FactoredMatrix: + """ + OV-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. 
(Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more) + + Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry! + + Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works. + """ + return FactoredMatrix(self.W_V, self.W_O) + + @property + def QK(self) -> FactoredMatrix: + """ + QK-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity in the key-query dot product, the output is purely determined by the matrix W_QK = W_Q.T @ W_K, and not W_Q or W_K individually. (Mathematically, for a single head, pattern = destination_residual.T @ W_Q.T @ W_K @ source-residual, see the glossary for more). + + Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos] + + Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works. 
+ """ + W_K_transpose = einops.rearrange( + self.W_K, "head_index d_model d_head -> head_index d_head d_model" + ) + return FactoredMatrix(self.W_Q, W_K_transpose) + + def forward( + self, + query_input: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ], + key_input: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ], + value_input: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ], + past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, + additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None, + attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, + ) -> Float[torch.Tensor, "batch pos d_model"]: + """ + shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details + past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None + additive_attention_mask is an optional mask to add to the attention weights. Defaults to None. + attention_mask is the attention mask for padded tokens. Defaults to None. 
+ """ + + if self.cfg.use_split_qkv_input or self.cfg.use_attn_in: + qkv_einops_string = "batch pos head_index d_model" + else: + qkv_einops_string = "batch pos d_model" + + q = self.hook_q( + einsum( + f"{qkv_einops_string}, head_index d_model d_head \ + -> batch pos head_index d_head", + query_input, + self.W_Q, + ) + + self.b_Q + ) # [batch, pos, head_index, d_head] + k = self.hook_k( + einsum( + f"{qkv_einops_string}, head_index d_model d_head \ + -> batch pos head_index d_head", + key_input, + self.W_K, + ) + + self.b_K + ) # [batch, pos, head_index, d_head] + v = self.hook_v( + einsum( + f"{qkv_einops_string}, head_index d_model d_head \ + -> batch pos head_index d_head", + value_input, + self.W_V, + ) + + self.b_V + ) # [batch, pos, head_index, d_head] + + if past_kv_cache_entry is not None: + # Appends the new keys and values to the cached values, and automatically updates the cache + kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1) + k, v = past_kv_cache_entry.append(k, v) + else: + # Not using a cache + kv_cache_pos_offset = 0 + + if self.cfg.positional_embedding_type == "rotary": + q = self.hook_rot_q( + self.apply_rotary(q, kv_cache_pos_offset, attention_mask) + ) + k = self.hook_rot_k( + self.apply_rotary(k, 0, attention_mask) + ) # keys are cached so no offset + + if self.cfg.dtype not in [torch.float32, torch.float64]: + # If using 16 bits, increase the precision to avoid numerical instabilities + q = q.to(torch.float32) + k = k.to(torch.float32) + + attn_scores = ( + einsum( + "batch query_pos head_index d_head, \ + batch key_pos head_index d_head \ + -> batch head_index query_pos key_pos", + q, + k, + ) + / self.attn_scale + ) # [batch, head_index, query_pos, key_pos] + + if self.cfg.positional_embedding_type == "alibi": + query_ctx = attn_scores.size(-2) + # The key context length is the number of positions in the past - this includes all positions in the cache + key_ctx = attn_scores.size(-1) + + alibi = 
self.get_cached_alibi(key_ctx=key_ctx) + + attn_scores += alibi[ + :, :query_ctx, :key_ctx + ] # [batch, head_index, query_pos, key_pos] + + if self.cfg.attention_dir == "causal": + # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. + attn_scores = self.apply_causal_mask( + attn_scores, kv_cache_pos_offset, attention_mask + ) # [batch, head_index, query_pos, key_pos] + + if additive_attention_mask is not None: + attn_scores += additive_attention_mask + + attn_scores = self.hook_attn_scores(attn_scores) + pattern = F.softmax(attn_scores, dim=-1) + pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) + pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] + pattern = pattern.to(self.cfg.dtype) + z = self.hook_z( + einsum( + "batch key_pos head_index d_head, \ + batch head_index query_pos key_pos -> \ + batch query_pos head_index d_head", + v, + pattern, + ) + ) # [batch, pos, head_index, d_head] + + if not self.cfg.use_attn_result: + return ( + ( + einsum( + "batch pos head_index d_head, \ + head_index d_head d_model -> \ + batch pos d_model", + z, + self.W_O, + ) + ) + + self.b_O + ) # [batch, pos, d_model] + else: + # Explicitly calculate the attention result so it can be accessed by a hook + # This is off by default because it can easily eat through your GPU memory. 
+ result = self.hook_result( + einsum( + "batch pos head_index d_head, \ + head_index d_head d_model -> \ + batch pos head_index d_model", + z, + self.W_O, + ) + ) # [batch, pos, head_index, d_model] + return ( + einops.reduce( + result, "batch position index model->batch position model", "sum" + ) + + self.b_O + ) # [batch, pos, d_model] + + + def apply_causal_mask( + self, + attn_scores: Float[ + torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset" + ], + past_kv_pos_offset: int = 0, + attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, + ): + # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different. + query_ctx_length = attn_scores.size(-2) + # The key context length is the number of positions in the past - this includes all positions in the cache + # If not caching, query_ctx_length == key_ctx_length + key_ctx_length = attn_scores.size(-1) + + assert ( + query_ctx_length + past_kv_pos_offset == key_ctx_length + ), f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug." 
+ + # Index back to front to ensure local attention works + final_mask = self.mask[ + None, None, -query_ctx_length:, -key_ctx_length: + ] # [1, 1, pos, pos] + + if attention_mask is not None: + # Apply a causal mask to the attention scores considering the padding + einsum_str = "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos" + final_mask = einops.einsum(final_mask, attention_mask, einsum_str).bool() + + return torch.where(final_mask, attn_scores, self.IGNORE) + + def calculate_sin_cos_rotary( + self, + rotary_dim: int, + n_ctx: int, + base: int = 10000, + dtype: torch.dtype = torch.float32, + ) -> Tuple[ + Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"] + ]: + """ + Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details + + Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, in GPT-NeoX the pair of elements at k and k+n//2 are rotated (ie folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent. + To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is. 
+ """ + high_precision = torch.float32 if dtype != torch.float64 else torch.float64 + pos = torch.arange(n_ctx, dtype=high_precision) + dim = torch.arange(rotary_dim // 2, dtype=high_precision) + + # A set of frequencies evenly spaced in log space + freq = base ** (dim / (rotary_dim / 2)) + if self.cfg.original_architecture in ["GPTNeoXForCausalLM", "LlamaForCausalLM"]: + freq = einops.repeat(freq, "d -> (2 d)") + else: + freq = einops.repeat(freq, "d -> (d 2)") + + # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency + angles = pos[:, None] / freq[None, :] + + return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) + + def rotate_every_two( + self, x: Float[torch.Tensor, "... rotary_dim"] + ) -> Float[torch.Tensor, "... rotary_dim"]: + """ + Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0] + + The final axis of x must have even length. + + GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details. 
+ """ + rot_x = x.clone() + if self.cfg.original_architecture in ["GPTNeoXForCausalLM", "LlamaForCausalLM"]: + n = x.size(-1) // 2 + rot_x[..., :n] = -x[..., n:] + rot_x[..., n:] = x[..., :n] + else: + rot_x[..., ::2] = -x[..., 1::2] + rot_x[..., 1::2] = x[..., ::2] + + return rot_x + + def apply_rotary( + self, + x: Float[torch.Tensor, "batch pos head_index d_head"], + past_kv_pos_offset=0, + attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, + ) -> Float[torch.Tensor, "batch pos head_index d_head"]: + # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions) + x_pos = x.size(1) + x_rot = x[..., : self.cfg.rotary_dim] + x_pass = x[..., self.cfg.rotary_dim :] + x_flip = self.rotate_every_two(x_rot) + + if attention_mask is None: + rotary_cos = self.rotary_cos[ + None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : + ] + rotary_sin = self.rotary_sin[ + None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : + ] + x_rotated = x_rot * rotary_cos + x_flip * rotary_sin + else: + offset_position_ids = get_offset_position_ids( + past_kv_pos_offset, attention_mask + ) + mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :] + mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :] + x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin + + return torch.cat([x_rotated, x_pass], dim=-1) + + @staticmethod + def create_alibi_slope( + n_ctx: int, device: torch.device = None + ) -> Float[torch.Tensor, "query key"]: + """Create an ALiBi Slope Matrix. + + Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar. + + See :meth:`create_alibi_bias` for the full ALiBi bias calculation. 
+
+ Examples:
+
+ >>> Attention.create_alibi_slope(3)
+ tensor([[ 0., 0., 0.],
+ [-1., 0., 0.],
+ [-2., -1., 0.]])
+
+ >>> Attention.create_alibi_slope(4)
+ tensor([[ 0., 0., 0., 0.],
+ [-1., 0., 0., 0.],
+ [-2., -1., 0., 0.],
+ [-3., -2., -1., 0.]])
+
+ Args:
+ n_ctx: The maximum number of tokens in a prompt.
+
+ Returns:
+ A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower
+ triangle is decreasing by a constant slope of 1 (towards the bottom left corner).
+ """
+ # set rows as [[0,1,2...]]
+ rows = torch.arange(n_ctx, device=device).unsqueeze(0)
+
+ # Set cols as [[0],[1],[2]...]
+ cols = torch.arange(n_ctx, device=device).unsqueeze(1)
+
+ # Use broadcasting to create the desired lower triangular part of the matrix
+ slope_matrix = rows - cols
+
+ # Use the clamp method to set all positive values (upper right triangle) to 0
+ return slope_matrix.clamp(max=0).to(torch.float32)
+
+ @staticmethod
+ def create_alibi_multipliers(
+ n_heads: int, device: torch.device = None
+ ) -> Float[torch.Tensor, "head_idx"]:
+ """Create the ALiBi Scalar Multipliers for each Head.
+
+ For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and
+ uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1),
+ 1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... , 1/(2^8)].
+
+ See :meth:`create_alibi_bias` for the full ALiBi bias calculation.
+
+ Examples:
+
+ >>> Attention.create_alibi_multipliers(8)
+ tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
+
+ >>> Attention.create_alibi_multipliers(16)
+ tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
+ 0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])
+
+ Args:
+ n_heads: The number of heads in a layer.
+ device: The device to create the tensor on.
+
+ Returns:
+ A tensor of shape (n_heads,) containing the scalar multiplier for each head.
+ """ + # Calculate the starting value + start = 2 ** (-8 / n_heads) + + # Generate the indices [0, 1, ..., n_heads-1] + indices = torch.arange(n_heads, device=device) + + # Compute the multipliers, with the starting value being the same as the ratio + multipliers = start * (start**indices) + + return multipliers + + @staticmethod + def create_alibi_bias( + n_heads: int, n_ctx: int, device: torch.device = None + ) -> Float[torch.Tensor, "head_idx query key"]: + """Create the ALiBi Bias for all Heads. + + Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer. + + The broad idea behind ALiBi is to remove the positional encoding from the original transformer + model, and instead apply a bias to each attention score. This bias is proportional to the + distance between the query and key (i.e. it encourage paying less attention to more distant + tokens), and is added to the attention scores before the softmax. It is used in models such as + Bloom. + + Examples: + + >>> Attention.create_alibi_bias(2, 4, torch.device('cpu')) + tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000], + [-0.0625, 0.0000, 0.0000, 0.0000], + [-0.1250, -0.0625, 0.0000, 0.0000], + [-0.1875, -0.1250, -0.0625, 0.0000]], + [[ 0.0000, 0.0000, 0.0000, 0.0000], + [-0.0039, 0.0000, 0.0000, 0.0000], + [-0.0078, -0.0039, 0.0000, 0.0000], + [-0.0117, -0.0078, -0.0039, 0.0000]]]) + + Args: + n_heads: The number of heads in a layer. + n_ctx: The maximum number of tokens in a prompt. + device: The device to create the tensor on. + + Returns: + The ALiBi bias that should be added to the attention scores before the softmax. + """ + # Create the slope matrix + slope: Float[torch.Tensor, "query key"] = Attention.create_alibi_slope( + n_ctx, device + ) + + # Create the scalar multiplier for each head. 
+ multipliers: Float[
+ torch.Tensor, "head_idx"
+ ] = Attention.create_alibi_multipliers(n_heads, device)
+
+ # The ALiBi bias is then m * slope_matrix
+ alibi_bias = torch.einsum("ij,k->kij", slope, multipliers)
+
+ return alibi_bias
+
+ def get_cached_alibi(self, key_ctx: int) -> Float[torch.Tensor, "head_idx query key"]:
+ """Get A Cached ALiBi bias For Calculation.
+
+ This function will check whether an instance of our ALiBi bias is currently set.
+ If the ALiBi bias is not set, or if our key context is greater than its cached size, a new
+ instance will be created.
+
+ The cached ALiBi bias is then returned.
+
+ Returns:
+ The ALiBi bias that should be added to the attention scores before the softmax.
+ """
+ # only recompute when necessary to increase efficiency.
+ if self.cached_alibi is None or key_ctx > self.cached_alibi.size(-1):
+ self.cached_alibi = Attention.create_alibi_bias(
+ self.cfg.n_heads, key_ctx, self.cfg.device
+ )
+
+ return self.cached_alibi
\ No newline at end of file
diff --git a/transformer_lens/components/bert_block.py b/transformer_lens/components/bert_block.py
new file mode 100644
index 000000000..0a572b86a
--- /dev/null
+++ b/transformer_lens/components/bert_block.py
@@ -0,0 +1,89 @@
+"""Hooked Transformer Components.
+
+This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`)
+needed to create many different types of generative language models. They are used by
+:class:`transformer_lens.HookedTransformer`.
+"""
+from .attention import Attention
+from .layer_norm import LayerNorm
+from .mlp import MLP
+import einops
+from jaxtyping import Float
+import torch
+import torch.nn as nn
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from typing import Optional
+
+
+class BertBlock(nn.Module):
+ """
+ BERT Block.
Similar to the TransformerBlock, except that the LayerNorms are applied after the attention and MLP, rather than before. + """ + + def __init__(self, cfg: HookedTransformerConfig): + super().__init__() + self.cfg = cfg + + self.attn = Attention(cfg) + self.ln1 = LayerNorm(cfg) + self.mlp = MLP(cfg) + self.ln2 = LayerNorm(cfg) + + self.hook_q_input = HookPoint() # [batch, pos, n_heads, d_model] + self.hook_k_input = HookPoint() # [batch, pos, n_heads, d_model] + self.hook_v_input = HookPoint() # [batch, pos, n_heads, d_model] + + self.hook_attn_out = HookPoint() # [batch, pos, d_model] + self.hook_mlp_in = HookPoint() # [batch, pos, d_model] + self.hook_mlp_out = HookPoint() # [batch, pos, d_model] + self.hook_resid_pre = HookPoint() # [batch, pos, d_model] + self.hook_resid_mid = HookPoint() # [batch, pos, d_model] + self.hook_resid_post = HookPoint() # [batch, pos, d_model] + self.hook_normalized_resid_post = HookPoint() # [batch, pos, d_model] + + def forward( + self, + resid_pre: Float[torch.Tensor, "batch pos d_model"], + additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None, + ): + resid_pre = self.hook_resid_pre(resid_pre) + + query_input = resid_pre + key_input = resid_pre + value_input = resid_pre + + if self.cfg.use_split_qkv_input: + + def add_head_dimension(tensor): + return einops.repeat( + tensor, + "batch pos d_model -> batch pos n_heads d_model", + n_heads=self.cfg.n_heads, + ).clone() + + query_input = self.hook_q_input(add_head_dimension(query_input)) + key_input = self.hook_k_input(add_head_dimension(key_input)) + value_input = self.hook_v_input(add_head_dimension(value_input)) + + attn_out = self.hook_attn_out( + self.attn( + query_input, + key_input, + value_input, + additive_attention_mask=additive_attention_mask, + ) + ) + resid_mid = self.hook_resid_mid(resid_pre + attn_out) + + mlp_in = ( + resid_mid + if not self.cfg.use_hook_mlp_in + else self.hook_mlp_in(resid_mid.clone()) + ) + normalized_resid_mid = 
self.ln1(mlp_in) + mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid)) + resid_post = self.hook_resid_post(normalized_resid_mid + mlp_out) + normalized_resid_post = self.hook_normalized_resid_post(self.ln2(resid_post)) + + return normalized_resid_post diff --git a/transformer_lens/components/bert_embed.py b/transformer_lens/components/bert_embed.py new file mode 100644 index 000000000..dfdde0bd2 --- /dev/null +++ b/transformer_lens/components/bert_embed.py @@ -0,0 +1,62 @@ +"""Hooked Transformer Components. + +This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) +needed to create many different types of generative language models. They are used by +:class:`transformer_lens.HookedTransformer`. +""" +from .embed import Embed +from .layer_norm import LayerNorm +from .pos_embed import PosEmbed +from .token_typed_embed import TokenTypeEmbed +import einops +from jaxtyping import Int +import torch +import torch.nn as nn +from transformer_lens.hook_points import HookPoint +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from typing import Dict, Optional, Union + + +class BertEmbed(nn.Module): + """ + Custom embedding layer for a BERT-like model. This module computes the sum of the token, positional and token-type embeddings and takes the layer norm of the result. 
+ """ + + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.embed = Embed(cfg) + self.pos_embed = PosEmbed(cfg) + self.token_type_embed = TokenTypeEmbed(cfg) + self.ln = LayerNorm(cfg) + + self.hook_embed = HookPoint() + self.hook_pos_embed = HookPoint() + self.hook_token_type_embed = HookPoint() + + def forward( + self, + input_ids: Int[torch.Tensor, "batch pos"], + token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, + ): + base_index_id = torch.arange(input_ids.shape[1], device=input_ids.device) + index_ids = einops.repeat( + base_index_id, "pos -> batch pos", batch=input_ids.shape[0] + ) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + word_embeddings_out = self.hook_embed(self.embed(input_ids)) + position_embeddings_out = self.hook_pos_embed(self.pos_embed(index_ids)) + token_type_embeddings_out = self.hook_token_type_embed( + self.token_type_embed(token_type_ids) + ) + + embeddings_out = ( + word_embeddings_out + position_embeddings_out + token_type_embeddings_out + ) + layer_norm_out = self.ln(embeddings_out) + return layer_norm_out + diff --git a/transformer_lens/components/bert_mlm_head.py b/transformer_lens/components/bert_mlm_head.py new file mode 100644 index 000000000..315447906 --- /dev/null +++ b/transformer_lens/components/bert_mlm_head.py @@ -0,0 +1,42 @@ +"""Hooked Transformer Components. + +This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) +needed to create many different types of generative language models. They are used by +:class:`transformer_lens.HookedTransformer`. 
+""" +from .layer_norm import LayerNorm +from fancy_einsum import einsum +from jaxtyping import Float +import torch +import torch.nn as nn +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from typing import Dict, Union + + +class BertMLMHead(nn.Module): + """ + Transforms BERT embeddings into logits. The purpose of this module is to predict masked tokens in a sentence. + """ + + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.W = nn.Parameter(torch.empty(cfg.d_model, cfg.d_model, dtype=cfg.dtype)) + self.b = nn.Parameter(torch.zeros(cfg.d_model, dtype=cfg.dtype)) + self.act_fn = nn.GELU() + self.ln = LayerNorm(cfg) + + def forward(self, resid: Float[torch.Tensor, "batch pos d_model"]) -> torch.Tensor: + resid = ( + einsum( + "batch pos d_model_in, d_model_out d_model_in -> batch pos d_model_out", + resid, + self.W, + ) + + self.b + ) + resid = self.act_fn(resid) + resid = self.ln(resid) + return resid diff --git a/transformer_lens/components/embed.py b/transformer_lens/components/embed.py new file mode 100644 index 000000000..62eb827b2 --- /dev/null +++ b/transformer_lens/components/embed.py @@ -0,0 +1,35 @@ +"""Hooked Transformer Embed Component. + +This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) +needed to create many different types of generative language models. They are used by +:class:`transformer_lens.HookedTransformer`. 
+""" +from .layer_norm import LayerNorm +from jaxtyping import Float, Int +import torch +import torch.nn as nn +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from typing import Dict, Union + +# Embed & Unembed +class Embed(nn.Module): + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter( + torch.empty(self.cfg.d_vocab, self.cfg.d_model, dtype=cfg.dtype) + ) + # Some models (e.g. Bloom) need post embedding layer norm + if cfg.post_embedding_ln: + self.ln = LayerNorm(cfg) + + def forward( + self, tokens: Int[torch.Tensor, "batch pos"] + ) -> Float[torch.Tensor, "batch pos d_model"]: + # If A has shape [a, b] and B has shape [c, d], then A[:, B] has shape [a, c, d] + # B acts as a tensor of indices into the second dimension (so >=0 and Float[torch.Tensor, "batch pos d_model"]: + # Technically, all these einsums could be done with a single matmul, but this is more readable. 
+ pre_act = self.hook_pre( + einsum( + "batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_gate + ) + ) # [batch, pos, d_mlp] + if not self.cfg.act_fn.endswith("_ln"): + pre_linear = self.hook_pre_linear( + einsum( + "batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in + ) + ) + post_act = self.hook_post( + (self.act_fn(pre_act) * pre_linear) + self.b_in + ) # [batch, pos, d_mlp] + else: + mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] + post_act = self.hook_post(self.ln(mid_act)) + return ( + einsum( + "batch pos d_mlp, d_mlp d_model -> batch pos d_model", + post_act, + self.W_out, + ) + + self.b_out + ) diff --git a/transformer_lens/components/layer_norm.py b/transformer_lens/components/layer_norm.py new file mode 100644 index 000000000..83841d93e --- /dev/null +++ b/transformer_lens/components/layer_norm.py @@ -0,0 +1,57 @@ +"""Hooked Transformer Components. + +This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) +needed to create many different types of generative language models. They are used by +:class:`transformer_lens.HookedTransformer`. +""" +from jaxtyping import Float +import torch +import torch.nn as nn +from transformer_lens.hook_points import HookPoint +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from typing import Dict, Optional, Union + + +class LayerNorm(nn.Module): + def __init__( + self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None + ): + """ + LayerNorm with optional length parameter + + length (Optional[int]): If the dimension of the LayerNorm. 
If not provided, assumed to be d_model + """ + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.eps = self.cfg.eps + self.length = self.cfg.d_model if length is None else length + + self.w = nn.Parameter(torch.ones(self.length, dtype=cfg.dtype)) + self.b = nn.Parameter(torch.zeros(self.length, dtype=cfg.dtype)) + + # Adds a hook point for the normalisation scale factor + self.hook_scale = HookPoint() # [batch, pos, 1] + # Hook_normalized is on the LN output + self.hook_normalized = HookPoint() # [batch, pos, length] + + def forward( + self, + x: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ], + ) -> Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ]: + if self.cfg.dtype not in [torch.float32, torch.float64]: + x = x.to(torch.float32) + + x = x - x.mean(axis=-1, keepdim=True) # [batch, pos, length] + scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale( + (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt() + ) + x = x / scale # [batch, pos, length] + return self.hook_normalized(x * self.w + self.b).to(self.cfg.dtype) \ No newline at end of file diff --git a/transformer_lens/components/layer_norm_pre.py b/transformer_lens/components/layer_norm_pre.py new file mode 100644 index 000000000..00c7eae2b --- /dev/null +++ b/transformer_lens/components/layer_norm_pre.py @@ -0,0 +1,55 @@ +"""Hooked Transformer Components. + +This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) +needed to create many different types of generative language models. They are used by +:class:`transformer_lens.HookedTransformer`. 
+""" +from jaxtyping import Float +import torch +import torch.nn as nn +from transformer_lens.hook_points import HookPoint +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from typing import Dict, Union + + +# LayerNormPre +# I fold the LayerNorm weights and biases into later weights and biases. +# This is just the 'center and normalise' part of LayerNorm +# Centering is equivalent to just deleting one direction of residual space, +# and is equivalent to centering the weight matrices of everything writing to the residual stream +# Normalising is a funkier non-linear operation, that projects the residual stream onto the unit hypersphere +class LayerNormPre(nn.Module): + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + """LayerNormPre - the 'center and normalise' part of LayerNorm. Length is + normally d_model, but is d_mlp for softmax. Not needed as a parameter. This + should only be used in inference mode after folding in LayerNorm weights""" + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.eps = self.cfg.eps + + # Adds a hook point for the normalisation scale factor + self.hook_scale = HookPoint() # [batch, pos] + # Hook Normalized captures LN output - here it's a vector with std 1 and mean 0 + self.hook_normalized = HookPoint() # [batch, pos, length] + + def forward( + self, + x: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ], + ) -> Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ]: + if self.cfg.dtype not in [torch.float32, torch.float64]: + x = x.to(torch.float32) + + x = x - x.mean(axis=-1, keepdim=True) # [batch, pos, length] + scale: Union[ + Float[torch.Tensor, "batch pos 1"], + Float[torch.Tensor, "batch pos head_index 1"], + ] = self.hook_scale((x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()) + return 
self.hook_normalized(x / scale).to(self.cfg.dtype) \ No newline at end of file diff --git a/transformer_lens/components/mlp.py b/transformer_lens/components/mlp.py new file mode 100644 index 000000000..d57fff224 --- /dev/null +++ b/transformer_lens/components/mlp.py @@ -0,0 +1,81 @@ +"""Hooked Transformer Components. + +This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) +needed to create many different types of generative language models. They are used by +:class:`transformer_lens.HookedTransformer`. +""" +from .layer_norm import LayerNorm +from .layer_norm_pre import LayerNormPre +from fancy_einsum import einsum +from jaxtyping import Float +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformer_lens.hook_points import HookPoint +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.utils import gelu_fast, gelu_new, solu +from typing import Dict, Union + + +# MLP Layers +class MLP(nn.Module): + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.W_in = nn.Parameter( + torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype) + ) + self.b_in = nn.Parameter(torch.zeros(self.cfg.d_mlp, dtype=cfg.dtype)) + self.W_out = nn.Parameter( + torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=cfg.dtype) + ) + self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype)) + + self.hook_pre = HookPoint() # [batch, pos, d_mlp] + self.hook_post = HookPoint() # [batch, pos, d_mlp] + + if self.cfg.act_fn == "relu": + self.act_fn = F.relu + elif self.cfg.act_fn == "gelu": + self.act_fn = F.gelu + elif self.cfg.act_fn == "silu": + self.act_fn = F.silu + elif self.cfg.act_fn == "gelu_new": + self.act_fn = gelu_new + elif self.cfg.act_fn == "gelu_fast": + self.act_fn = gelu_fast + elif self.cfg.act_fn == "solu_ln": 
+ self.act_fn = solu + # Hook taken between activation and layer norm + self.hook_mid = HookPoint() # [batch, pos, d_mlp] + if self.cfg.normalization_type == "LN": + self.ln = LayerNorm(self.cfg, self.cfg.d_mlp) + else: + self.ln = LayerNormPre(self.cfg) + + else: + raise ValueError(f"Invalid activation function name: {self.cfg.act_fn}") + + def forward( + self, x: Float[torch.Tensor, "batch pos d_model"] + ) -> Float[torch.Tensor, "batch pos d_model"]: + # Technically, all these einsums could be done with a single matmul, but this is more readable. + pre_act = self.hook_pre( + einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in) + + self.b_in + ) # [batch, pos, d_mlp] + if not self.cfg.act_fn.endswith("_ln"): + post_act = self.hook_post(self.act_fn(pre_act)) # [batch, pos, d_mlp] TODO segmentation fault + else: + mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] + post_act = self.hook_post(self.ln(mid_act)) + return ( + einsum( + "batch pos d_mlp, d_mlp d_model -> batch pos d_model", + post_act, + self.W_out, + ) + + self.b_out + ) \ No newline at end of file diff --git a/transformer_lens/components/pos_embed.py b/transformer_lens/components/pos_embed.py new file mode 100644 index 000000000..62ce9992f --- /dev/null +++ b/transformer_lens/components/pos_embed.py @@ -0,0 +1,73 @@ +"""Hooked Transformer Embed Component. + +This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) +needed to create many different types of generative language models. They are used by +:class:`transformer_lens.HookedTransformer`. 
+""" + +import einops +from jaxtyping import Float, Int +import torch +import torch.nn as nn +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.utils import get_offset_position_ids +from typing import Dict, Optional, Union + +# Positional Embeddings +class PosEmbed(nn.Module): + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.W_pos = nn.Parameter( + torch.empty(self.cfg.n_ctx, self.cfg.d_model, dtype=cfg.dtype) + ) + + def forward( + self, + tokens: Int[torch.Tensor, "batch pos"], + past_kv_pos_offset: int = 0, + attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, + ) -> Float[torch.Tensor, "batch pos d_model"]: + """ + Forward pass for positional embeddings. + + Args: + tokens (Int[torch.Tensor, "batch pos"]): Input tokens. + past_kv_pos_offset (int, optional): The length of tokens in the past_kv_cache. Defaults to 0. + attention_mask (Int[torch.Tensor, "batch pos"], optional): The attention mask for padded tokens. + Defaults to None. + + Returns: + Float[torch.Tensor, "batch pos d_model"]: Absolute position embeddings. 
+ """ + tokens_length = tokens.size(-1) + + if attention_mask is None: + pos_embed = self.W_pos[ + past_kv_pos_offset : tokens_length + past_kv_pos_offset, : + ] # [pos, d_model] + batch_pos_embed = einops.repeat( + pos_embed, "pos d_model -> batch pos d_model", batch=tokens.size(0) + ) + + else: + # Separated from the no padding case for computational efficiency + # (this code is a bit slower than the code above) + + offset_position_ids = get_offset_position_ids( + past_kv_pos_offset, attention_mask + ) + pos_embed = self.W_pos[offset_position_ids] # [batch, pos, d_model] + + # Set the position embeddings to 0 for pad tokens (this is an arbitrary choice) + padding_mask = ~attention_mask.bool() # [batch, tokens_length] + offset_padding_mask = padding_mask[ + :, past_kv_pos_offset : tokens_length + past_kv_pos_offset + ].unsqueeze( + -1 + ) # [batch, pos, 1] + batch_pos_embed = torch.where(offset_padding_mask, 0, pos_embed) + + return batch_pos_embed.clone() diff --git a/transformer_lens/components/rms_norm.py b/transformer_lens/components/rms_norm.py new file mode 100644 index 000000000..6f5909ca6 --- /dev/null +++ b/transformer_lens/components/rms_norm.py @@ -0,0 +1,52 @@ +"""Hooked Transformer Components. + +This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) +needed to create many different types of generative language models. They are used by +:class:`transformer_lens.HookedTransformer`. +""" +from jaxtyping import Float +import torch +import torch.nn as nn +from transformer_lens.hook_points import HookPoint +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from typing import Dict, Optional, Union + + + + +class RMSNorm(nn.Module): + def __init__( + self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None + ): + """ + RMSNorm - LayerNorm without the centering and bias (RMS = Root Mean Square) + + length (Optional[int]): If the dimension of the RMSNorm. 
If not provided, assumed to be d_model + """ + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.eps = self.cfg.eps + if length is None: + self.length = self.cfg.d_model + else: + self.length = length + + self.w = nn.Parameter(torch.ones(self.length, dtype=cfg.dtype)) + + # Adds a hook point for the normalisation scale factor + self.hook_scale = HookPoint() # [batch, pos, 1] + self.hook_normalized = HookPoint() # [batch, pos, length] + + def forward( + self, x: Float[torch.Tensor, "batch pos length"] + ) -> Float[torch.Tensor, "batch pos length"]: + if self.cfg.dtype not in [torch.float32, torch.float64]: + x = x.to(torch.float32) + + scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale( + (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt() + ) + x = self.hook_normalized(x / scale).to(self.cfg.dtype) # [batch, pos, length] + return x * self.w \ No newline at end of file diff --git a/transformer_lens/components/rms_norm_pre.py b/transformer_lens/components/rms_norm_pre.py new file mode 100644 index 000000000..977b91287 --- /dev/null +++ b/transformer_lens/components/rms_norm_pre.py @@ -0,0 +1,39 @@ +"""Hooked Transformer Components. + +This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) +needed to create many different types of generative language models. They are used by +:class:`transformer_lens.HookedTransformer`. 
+"""
+from jaxtyping import Float
+import torch
+import torch.nn as nn
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from typing import Dict, Union
+
+
+class RMSNormPre(nn.Module):
+ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
+ """RMSNormPre - LayerNormPre without the centering and bias (RMS = Root Mean Square)"""
+ super().__init__()
+ if isinstance(cfg, Dict):
+ cfg = HookedTransformerConfig.from_dict(cfg)
+ self.cfg = cfg
+ self.eps = self.cfg.eps
+
+ # Adds a hook point for the normalisation scale factor
+ self.hook_scale = HookPoint() # [batch, pos]
+ self.hook_normalized = HookPoint() # [batch, pos, length]
+
+ def forward(
+ self, x: Float[torch.Tensor, "batch pos length"]
+ ) -> Float[torch.Tensor, "batch pos length"]:
+ if self.cfg.dtype not in [torch.float32, torch.float64]:
+ x = x.to(torch.float32)
+
+ scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
+ (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
+ )
+ return self.hook_normalized(x / scale).to(
+ self.cfg.dtype
+ ) # [batch, pos, length]
diff --git a/transformer_lens/components/token_typed_embed.py b/transformer_lens/components/token_typed_embed.py
new file mode 100644
index 000000000..84ac6c8b5
--- /dev/null
+++ b/transformer_lens/components/token_typed_embed.py
@@ -0,0 +1,32 @@
+"""Hooked Transformer Components.
+
+This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`)
+needed to create many different types of generative language models. They are used by
+:class:`transformer_lens.HookedTransformer`.
+"""
+
+from jaxtyping import Int
+import torch
+import torch.nn as nn
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+from typing import Dict, Union
+
+
+class TokenTypeEmbed(nn.Module):
+ """
+ The token-type embed is a binary id indicating whether a token belongs to sequence A or B.
For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, shape is (batch_size, sequence_length). + + See the BERT paper for more information: https://arxiv.org/pdf/1810.04805.pdf + """ + + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.W_token_type = nn.Parameter( + torch.empty(2, self.cfg.d_model, dtype=cfg.dtype) + ) + + def forward(self, token_type_ids: Int[torch.Tensor, "batch pos"]): + return self.W_token_type[token_type_ids, :] \ No newline at end of file diff --git a/transformer_lens/components/transformer_block.py b/transformer_lens/components/transformer_block.py new file mode 100644 index 000000000..9eafcb082 --- /dev/null +++ b/transformer_lens/components/transformer_block.py @@ -0,0 +1,189 @@ +"""Hooked Transformer Components. + +This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) +needed to create many different types of generative language models. They are used by +:class:`transformer_lens.HookedTransformer`. 
+""" +from .attention import Attention +from .gated_mlp import GatedMLP +from .layer_norm import LayerNorm +from .layer_norm_pre import LayerNormPre +from .mlp import MLP +from .rms_norm import RMSNorm +from .rms_norm_pre import RMSNormPre +import einops +from jaxtyping import Float, Int +import logging +import torch +import torch.nn as nn +from transformer_lens.hook_points import HookPoint +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry +from typing import Dict, Optional, Union + + +# Transformer Block +class TransformerBlock(nn.Module): + def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + if self.cfg.normalization_type == "LN": + self.ln1 = LayerNorm(cfg) + if not self.cfg.attn_only: + self.ln2 = LayerNorm(cfg) + elif self.cfg.normalization_type == "LNPre": + # We've folded in LayerNorm weights, so just need the center + scale parts + self.ln1 = LayerNormPre(cfg) + if not self.cfg.attn_only: + self.ln2 = LayerNormPre(cfg) + elif self.cfg.normalization_type == "RMS": + self.ln1 = RMSNorm(cfg) + if not self.cfg.attn_only: + self.ln2 = RMSNorm(cfg) + elif self.cfg.normalization_type == "RMSPre": + self.ln1 = RMSNormPre(cfg) + if not self.cfg.attn_only: + self.ln2 = RMSNormPre(cfg) + elif self.cfg.normalization_type is None: + self.ln1 = nn.Identity() + if not self.cfg.attn_only: + self.ln2 = nn.Identity() + else: + logging.warning( + f"Invalid normalization_type passed in {self.cfg.normalization_type}" + ) + + if not self.cfg.use_local_attn: + self.attn = Attention(cfg, "global", block_index) + else: + assert self.cfg.attn_types is not None + attn_type = self.cfg.attn_types[block_index] + self.attn = Attention(cfg, attn_type, block_index) + if not self.cfg.attn_only: + if self.cfg.gated_mlp: + self.mlp = 
GatedMLP(cfg) + else: + self.mlp = MLP(cfg) + + self.hook_attn_in = HookPoint() # [batch, pos, n_heads, d_model] + self.hook_q_input = HookPoint() # [batch, pos, n_heads, d_model] + self.hook_k_input = HookPoint() # [batch, pos, n_heads, d_model] + self.hook_v_input = HookPoint() # [batch, pos, n_heads, d_model] + self.hook_mlp_in = HookPoint() # [batch, pos, d_model] + + self.hook_attn_out = HookPoint() # [batch, pos, d_model] + self.hook_mlp_out = HookPoint() # [batch, pos, d_model] + + self.hook_resid_pre = HookPoint() # [batch, pos, d_model] + if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp: + self.hook_resid_mid = HookPoint() # [batch, pos, d_model] + self.hook_resid_post = HookPoint() # [batch, pos, d_model] + + def forward( + self, + resid_pre: Float[torch.Tensor, "batch pos d_model"], + shortformer_pos_embed: Optional[ + Float[torch.Tensor, "batch pos d_model"] + ] = None, + past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, + attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, + ) -> Float[torch.Tensor, "batch pos d_model"]: + """A single Transformer block. + + Args: + resid_pre (torch.Tensor): The residual stream - shape [batch, pos, d_model] + cache (HookedTransformerKeyValueCache): A cache of previous keys and values, used only when generating text. Defaults to None. + shortformer_pos_embed (torch.Tensor, optional): Only used for positional_embeddings_type == "shortformer". The positional embeddings. See HookedTransformerConfig for details. Defaults to None. + attention_mask (torch.Tensor, optional): The attention mask for padded tokens. Defaults to None. 
+ + Returns: + _type_: _description_ + """ + resid_pre = self.hook_resid_pre(resid_pre) # [batch, pos, d_model] + + def add_head_dimension( + tensor: Float[torch.Tensor, "batch pos d_model"], + clone_tensor=True, + # `einops.repeat` uses a view in torch, so we generally clone the tensor to avoid using shared storage for each head entry + ): + repeated_tensor = einops.repeat( + tensor, + "batch pos d_model -> batch pos n_heads d_model", + n_heads=self.cfg.n_heads, + ) + if clone_tensor: + return repeated_tensor.clone() + else: + return repeated_tensor + + if self.cfg.use_attn_in or self.cfg.use_split_qkv_input: + # We're adding a head dimension + attn_in = add_head_dimension(resid_pre, clone_tensor=False) + if shortformer_pos_embed is not None: + shortformer_pos_embed = add_head_dimension(shortformer_pos_embed) + else: + attn_in = resid_pre + + if self.cfg.use_attn_in: + attn_in = self.hook_attn_in(attn_in.clone()) + + if self.cfg.use_split_qkv_input: + query_input = self.hook_q_input(attn_in.clone()) + key_input = self.hook_k_input(attn_in.clone()) + value_input = self.hook_v_input(attn_in.clone()) + else: + query_input = attn_in + key_input = attn_in + value_input = attn_in + + attn_out = self.hook_attn_out( + # hook the residual stream states that are used to calculate the + # queries, keys and values, independently. + # Then take the layer norm of these inputs, and pass these to the attention module. 
+ self.attn( + query_input=self.ln1(query_input) + + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed), + key_input=self.ln1(key_input) + + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed), + value_input=self.ln1(value_input), + past_kv_cache_entry=past_kv_cache_entry, + attention_mask=attention_mask, + ) + ) # [batch, pos, d_model] + if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp: + resid_mid = self.hook_resid_mid( + resid_pre + attn_out + ) # [batch, pos, d_model] + mlp_in = ( + resid_mid + if not self.cfg.use_hook_mlp_in + else self.hook_mlp_in(resid_mid.clone()) + ) + normalized_resid_mid = self.ln2(mlp_in) + mlp_out = self.hook_mlp_out( + self.mlp(normalized_resid_mid) + ) # [batch, pos, d_model] + resid_post = self.hook_resid_post( + resid_mid + mlp_out + ) # [batch, pos, d_model] + elif self.cfg.parallel_attn_mlp: + # Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used. + # In GPT-J, LN1 and LN2 are tied, in GPT-NeoX they aren't. + normalized_resid_pre_2 = self.ln2( + resid_pre + if not self.cfg.use_hook_mlp_in + else self.hook_mlp_in(resid_pre.clone()) + ) + mlp_out = self.hook_mlp_out( + self.mlp(normalized_resid_pre_2) + ) # [batch, pos, d_model] + resid_post = self.hook_resid_post( + resid_pre + attn_out + mlp_out + ) # [batch, pos, d_model] + else: + resid_post = self.hook_resid_post( + resid_pre + attn_out + ) # [batch, pos, d_model] + return resid_post diff --git a/transformer_lens/components/unembed.py b/transformer_lens/components/unembed.py new file mode 100644 index 000000000..6f0799019 --- /dev/null +++ b/transformer_lens/components/unembed.py @@ -0,0 +1,39 @@ +"""Hooked Transformer Unembed Component. + +This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) +needed to create many different types of generative language models. They are used by +:class:`transformer_lens.HookedTransformer`. 
+""" + +from fancy_einsum import einsum +from jaxtyping import Float +import torch +import torch.nn as nn +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from typing import Dict, Union + +class Unembed(nn.Module): + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + # Note that there's a separate variable for d_vocab_out and d_vocab (the input vocab size). For language tasks these are always the same, but for algorithmic tasks we may want them to be different. + self.W_U: Float[torch.Tensor, "d_model d_vocab_out"] = nn.Parameter( + torch.empty(self.cfg.d_model, self.cfg.d_vocab_out, dtype=cfg.dtype) + ) + self.b_U: Float[torch.Tensor, "d_vocab_out"] = nn.Parameter( + torch.zeros(self.cfg.d_vocab_out, dtype=cfg.dtype) + ) + + def forward( + self, residual: Float[torch.Tensor, "batch pos d_model"] + ) -> Float[torch.Tensor, "batch pos d_vocab_out"]: + return ( + einsum( + "batch pos d_model, d_model vocab -> batch pos vocab", + residual, + self.W_U, + ) + + self.b_U + ) From 444fb32adaa282e0c62ea3ce898dc9d8dfc05121 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Fri, 17 Nov 2023 01:11:49 +0100 Subject: [PATCH 02/73] reworked components a bit --- transformer_lens/components/BertEmbed.py | 62 ------------------- transformer_lens/components/__init__.py | 21 ++++--- transformer_lens/components/bert_block.py | 4 +- transformer_lens/components/bert_embed.py | 5 +- transformer_lens/components/bert_mlm_head.py | 2 +- transformer_lens/components/embed.py | 2 +- transformer_lens/components/gated_mlp.py | 3 +- transformer_lens/components/mlp.py | 3 +- transformer_lens/components/pos_embed.py | 1 - transformer_lens/components/rms_norm.py | 2 - .../components/token_typed_embed.py | 1 - .../components/transformer_block.py | 8 +-- transformer_lens/components/unembed.py | 1 - 13 files changed, 20 insertions(+), 95 deletions(-) 
delete mode 100644 transformer_lens/components/BertEmbed.py diff --git a/transformer_lens/components/BertEmbed.py b/transformer_lens/components/BertEmbed.py deleted file mode 100644 index dfdde0bd2..000000000 --- a/transformer_lens/components/BertEmbed.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Hooked Transformer Components. - -This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. -""" -from .embed import Embed -from .layer_norm import LayerNorm -from .pos_embed import PosEmbed -from .token_typed_embed import TokenTypeEmbed -import einops -from jaxtyping import Int -import torch -import torch.nn as nn -from transformer_lens.hook_points import HookPoint -from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -from typing import Dict, Optional, Union - - -class BertEmbed(nn.Module): - """ - Custom embedding layer for a BERT-like model. This module computes the sum of the token, positional and token-type embeddings and takes the layer norm of the result. 
- """ - - def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): - super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg - self.embed = Embed(cfg) - self.pos_embed = PosEmbed(cfg) - self.token_type_embed = TokenTypeEmbed(cfg) - self.ln = LayerNorm(cfg) - - self.hook_embed = HookPoint() - self.hook_pos_embed = HookPoint() - self.hook_token_type_embed = HookPoint() - - def forward( - self, - input_ids: Int[torch.Tensor, "batch pos"], - token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, - ): - base_index_id = torch.arange(input_ids.shape[1], device=input_ids.device) - index_ids = einops.repeat( - base_index_id, "pos -> batch pos", batch=input_ids.shape[0] - ) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - word_embeddings_out = self.hook_embed(self.embed(input_ids)) - position_embeddings_out = self.hook_pos_embed(self.pos_embed(index_ids)) - token_type_embeddings_out = self.hook_token_type_embed( - self.token_type_embed(token_type_ids) - ) - - embeddings_out = ( - word_embeddings_out + position_embeddings_out + token_type_embeddings_out - ) - layer_norm_out = self.ln(embeddings_out) - return layer_norm_out - diff --git a/transformer_lens/components/__init__.py b/transformer_lens/components/__init__.py index 14bb14085..b6b972839 100644 --- a/transformer_lens/components/__init__.py +++ b/transformer_lens/components/__init__.py @@ -1,15 +1,20 @@ +# Independent classes from .attention import Attention -from .bert_block import BertBlock -from .bert_embed import BertEmbed -from .bert_mlm_head import BertMLMHead -from .embed import Embed -from .gated_mlp import GatedMLP from .layer_norm import LayerNorm from .layer_norm_pre import LayerNormPre -from .mlp import MLP from .pos_embed import PosEmbed from .rms_norm import RMSNorm from .rms_norm_pre import RMSNormPre from .token_typed_embed import TokenTypeEmbed -from .transformer_block import TransformerBlock -from 
.unembed import Unembed \ No newline at end of file +from .unembed import Unembed + +# Only dependent on independent modules +from .bert_mlm_head import BertMLMHead +from .embed import Embed +from .gated_mlp import GatedMLP +from .mlp import MLP + +# Interdependent modules +from .bert_block import BertBlock +from .bert_embed import BertEmbed +from .transformer_block import TransformerBlock \ No newline at end of file diff --git a/transformer_lens/components/bert_block.py b/transformer_lens/components/bert_block.py index 0a572b86a..8a77d6763 100644 --- a/transformer_lens/components/bert_block.py +++ b/transformer_lens/components/bert_block.py @@ -4,13 +4,11 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. """ -from .attention import Attention -from .layer_norm import LayerNorm -from .mlp import MLP import einops from jaxtyping import Float import torch import torch.nn as nn +from transformer_lens.components import Attention, LayerNorm, MLP from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from typing import Optional diff --git a/transformer_lens/components/bert_embed.py b/transformer_lens/components/bert_embed.py index dfdde0bd2..63583079f 100644 --- a/transformer_lens/components/bert_embed.py +++ b/transformer_lens/components/bert_embed.py @@ -4,14 +4,11 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. 
""" -from .embed import Embed -from .layer_norm import LayerNorm -from .pos_embed import PosEmbed -from .token_typed_embed import TokenTypeEmbed import einops from jaxtyping import Int import torch import torch.nn as nn +from transformer_lens.components import Embed, LayerNorm, PosEmbed, TokenTypeEmbed from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from typing import Dict, Optional, Union diff --git a/transformer_lens/components/bert_mlm_head.py b/transformer_lens/components/bert_mlm_head.py index 315447906..56f2b496e 100644 --- a/transformer_lens/components/bert_mlm_head.py +++ b/transformer_lens/components/bert_mlm_head.py @@ -4,11 +4,11 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. """ -from .layer_norm import LayerNorm from fancy_einsum import einsum from jaxtyping import Float import torch import torch.nn as nn +from transformer_lens.components import LayerNorm from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from typing import Dict, Union diff --git a/transformer_lens/components/embed.py b/transformer_lens/components/embed.py index 62eb827b2..946b79776 100644 --- a/transformer_lens/components/embed.py +++ b/transformer_lens/components/embed.py @@ -4,10 +4,10 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. 
""" -from .layer_norm import LayerNorm from jaxtyping import Float, Int import torch import torch.nn as nn +from transformer_lens.components import LayerNorm from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from typing import Dict, Union diff --git a/transformer_lens/components/gated_mlp.py b/transformer_lens/components/gated_mlp.py index c40f632eb..e9e0fc213 100644 --- a/transformer_lens/components/gated_mlp.py +++ b/transformer_lens/components/gated_mlp.py @@ -4,13 +4,12 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. """ -from .layer_norm import LayerNorm -from .layer_norm_pre import LayerNormPre from fancy_einsum import einsum from jaxtyping import Float import torch import torch.nn as nn import torch.nn.functional as F +from transformer_lens.components import LayerNorm, LayerNormPre from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.utils import gelu_fast, gelu_new, solu diff --git a/transformer_lens/components/mlp.py b/transformer_lens/components/mlp.py index d57fff224..b06ff4ce0 100644 --- a/transformer_lens/components/mlp.py +++ b/transformer_lens/components/mlp.py @@ -4,13 +4,12 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. 
""" -from .layer_norm import LayerNorm -from .layer_norm_pre import LayerNormPre from fancy_einsum import einsum from jaxtyping import Float import torch import torch.nn as nn import torch.nn.functional as F +from transformer_lens.components import LayerNorm, LayerNormPre from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.utils import gelu_fast, gelu_new, solu diff --git a/transformer_lens/components/pos_embed.py b/transformer_lens/components/pos_embed.py index 62ce9992f..a0d11407f 100644 --- a/transformer_lens/components/pos_embed.py +++ b/transformer_lens/components/pos_embed.py @@ -4,7 +4,6 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. """ - import einops from jaxtyping import Float, Int import torch diff --git a/transformer_lens/components/rms_norm.py b/transformer_lens/components/rms_norm.py index 6f5909ca6..8e04887bd 100644 --- a/transformer_lens/components/rms_norm.py +++ b/transformer_lens/components/rms_norm.py @@ -12,8 +12,6 @@ from typing import Dict, Optional, Union - - class RMSNorm(nn.Module): def __init__( self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None diff --git a/transformer_lens/components/token_typed_embed.py b/transformer_lens/components/token_typed_embed.py index 84ac6c8b5..985ad6170 100644 --- a/transformer_lens/components/token_typed_embed.py +++ b/transformer_lens/components/token_typed_embed.py @@ -4,7 +4,6 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. 
""" - from jaxtyping import Int import torch import torch.nn as nn diff --git a/transformer_lens/components/transformer_block.py b/transformer_lens/components/transformer_block.py index 9eafcb082..98b145423 100644 --- a/transformer_lens/components/transformer_block.py +++ b/transformer_lens/components/transformer_block.py @@ -4,18 +4,12 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. """ -from .attention import Attention -from .gated_mlp import GatedMLP -from .layer_norm import LayerNorm -from .layer_norm_pre import LayerNormPre -from .mlp import MLP -from .rms_norm import RMSNorm -from .rms_norm_pre import RMSNormPre import einops from jaxtyping import Float, Int import logging import torch import torch.nn as nn +from transformer_lens.components import Attention, GatedMLP, LayerNorm, LayerNormPre, MLP, RMSNorm, RMSNormPre from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry diff --git a/transformer_lens/components/unembed.py b/transformer_lens/components/unembed.py index 6f0799019..de690c3f5 100644 --- a/transformer_lens/components/unembed.py +++ b/transformer_lens/components/unembed.py @@ -4,7 +4,6 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. 
""" - from fancy_einsum import einsum from jaxtyping import Float import torch From 68c0c15fa30569e77508741499b363a889c3df43 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Tue, 21 Nov 2023 23:40:23 +0100 Subject: [PATCH 03/73] fixed cached data setting --- transformer_lens/__init__.py | 1 - transformer_lens/components/attention.py | 9 ++++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py index e2fb1484b..affa9b69d 100644 --- a/transformer_lens/__init__.py +++ b/transformer_lens/__init__.py @@ -5,7 +5,6 @@ HookedTransformerKeyValueCache, HookedTransformerKeyValueCacheEntry, ) -from . import components from .HookedTransformerConfig import HookedTransformerConfig from .FactoredMatrix import FactoredMatrix from .ActivationCache import ActivationCache diff --git a/transformer_lens/components/attention.py b/transformer_lens/components/attention.py index 04c8eff4c..e91143cca 100644 --- a/transformer_lens/components/attention.py +++ b/transformer_lens/components/attention.py @@ -5,19 +5,21 @@ :class:`transformer_lens.HookedTransformer`. 
""" +from typing import Dict, Optional, Tuple, Union + import einops -from fancy_einsum import einsum -from jaxtyping import Float, Int import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from fancy_einsum import einsum +from jaxtyping import Float, Int + from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry from transformer_lens.utils import get_offset_position_ids -from typing import Dict, Optional, Tuple, Union # Attention @@ -39,6 +41,7 @@ def __init__( """ super().__init__() + self.cached_alibi = None self.cfg = HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg self.W_Q = nn.Parameter( From 11105de2cc7ffa2bf0d8bada961c1feabf00cc6b Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 22 Nov 2023 00:20:44 +0100 Subject: [PATCH 04/73] ran format --- transformer_lens/components/bert_block.py | 8 +++++--- transformer_lens/components/bert_embed.py | 6 ++++-- transformer_lens/components/bert_mlm_head.py | 8 +++++--- transformer_lens/components/embed.py | 7 +++++-- transformer_lens/components/gated_mlp.py | 11 +++++------ transformer_lens/components/layer_norm.py | 6 ++++-- transformer_lens/components/layer_norm_pre.py | 6 ++++-- transformer_lens/components/mlp.py | 8 +++++--- transformer_lens/components/pos_embed.py | 7 +++++-- transformer_lens/components/rms_norm.py | 6 ++++-- transformer_lens/components/rms_norm_pre.py | 6 ++++-- .../components/token_typed_embed.py | 6 ++++-- .../components/transformer_block.py | 18 ++++++++++++++---- transformer_lens/components/unembed.py | 9 ++++++--- 14 files changed, 74 insertions(+), 38 deletions(-) diff --git a/transformer_lens/components/bert_block.py b/transformer_lens/components/bert_block.py index 8a77d6763..e531bcd91 100644 --- 
a/transformer_lens/components/bert_block.py +++ b/transformer_lens/components/bert_block.py @@ -4,14 +4,16 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. """ +from typing import Optional + import einops -from jaxtyping import Float import torch import torch.nn as nn -from transformer_lens.components import Attention, LayerNorm, MLP +from jaxtyping import Float + +from transformer_lens.components import MLP, Attention, LayerNorm from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -from typing import Optional class BertBlock(nn.Module): diff --git a/transformer_lens/components/bert_embed.py b/transformer_lens/components/bert_embed.py index 63583079f..a9541d61f 100644 --- a/transformer_lens/components/bert_embed.py +++ b/transformer_lens/components/bert_embed.py @@ -4,14 +4,16 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. """ +from typing import Dict, Optional, Union + import einops -from jaxtyping import Int import torch import torch.nn as nn +from jaxtyping import Int + from transformer_lens.components import Embed, LayerNorm, PosEmbed, TokenTypeEmbed from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -from typing import Dict, Optional, Union class BertEmbed(nn.Module): diff --git a/transformer_lens/components/bert_mlm_head.py b/transformer_lens/components/bert_mlm_head.py index 56f2b496e..745fc5d13 100644 --- a/transformer_lens/components/bert_mlm_head.py +++ b/transformer_lens/components/bert_mlm_head.py @@ -4,13 +4,15 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. 
""" -from fancy_einsum import einsum -from jaxtyping import Float +from typing import Dict, Union + import torch import torch.nn as nn +from fancy_einsum import einsum +from jaxtyping import Float + from transformer_lens.components import LayerNorm from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -from typing import Dict, Union class BertMLMHead(nn.Module): diff --git a/transformer_lens/components/embed.py b/transformer_lens/components/embed.py index 946b79776..0d988207b 100644 --- a/transformer_lens/components/embed.py +++ b/transformer_lens/components/embed.py @@ -4,12 +4,15 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. """ -from jaxtyping import Float, Int +from typing import Dict, Union + import torch import torch.nn as nn +from jaxtyping import Float, Int + from transformer_lens.components import LayerNorm from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -from typing import Dict, Union + # Embed & Unembed class Embed(nn.Module): diff --git a/transformer_lens/components/gated_mlp.py b/transformer_lens/components/gated_mlp.py index e9e0fc213..7b74272b4 100644 --- a/transformer_lens/components/gated_mlp.py +++ b/transformer_lens/components/gated_mlp.py @@ -4,19 +4,18 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. 
""" -from fancy_einsum import einsum -from jaxtyping import Float +from typing import Dict, Union + import torch import torch.nn as nn import torch.nn.functional as F +from fancy_einsum import einsum +from jaxtyping import Float + from transformer_lens.components import LayerNorm, LayerNormPre from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.utils import gelu_fast, gelu_new, solu -from typing import Dict, Union - - - # TODO diff --git a/transformer_lens/components/layer_norm.py b/transformer_lens/components/layer_norm.py index 83841d93e..8da3820f4 100644 --- a/transformer_lens/components/layer_norm.py +++ b/transformer_lens/components/layer_norm.py @@ -4,12 +4,14 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. """ -from jaxtyping import Float +from typing import Dict, Optional, Union + import torch import torch.nn as nn +from jaxtyping import Float + from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -from typing import Dict, Optional, Union class LayerNorm(nn.Module): diff --git a/transformer_lens/components/layer_norm_pre.py b/transformer_lens/components/layer_norm_pre.py index 00c7eae2b..51c2e160e 100644 --- a/transformer_lens/components/layer_norm_pre.py +++ b/transformer_lens/components/layer_norm_pre.py @@ -4,12 +4,14 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. 
""" -from jaxtyping import Float +from typing import Dict, Union + import torch import torch.nn as nn +from jaxtyping import Float + from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -from typing import Dict, Union # LayerNormPre diff --git a/transformer_lens/components/mlp.py b/transformer_lens/components/mlp.py index b06ff4ce0..a05ce0c47 100644 --- a/transformer_lens/components/mlp.py +++ b/transformer_lens/components/mlp.py @@ -4,16 +4,18 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. """ -from fancy_einsum import einsum -from jaxtyping import Float +from typing import Dict, Union + import torch import torch.nn as nn import torch.nn.functional as F +from fancy_einsum import einsum +from jaxtyping import Float + from transformer_lens.components import LayerNorm, LayerNormPre from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.utils import gelu_fast, gelu_new, solu -from typing import Dict, Union # MLP Layers diff --git a/transformer_lens/components/pos_embed.py b/transformer_lens/components/pos_embed.py index a0d11407f..6d4909a94 100644 --- a/transformer_lens/components/pos_embed.py +++ b/transformer_lens/components/pos_embed.py @@ -4,13 +4,16 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. 
""" +from typing import Dict, Optional, Union + import einops -from jaxtyping import Float, Int import torch import torch.nn as nn +from jaxtyping import Float, Int + from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.utils import get_offset_position_ids -from typing import Dict, Optional, Union + # Positional Embeddings class PosEmbed(nn.Module): diff --git a/transformer_lens/components/rms_norm.py b/transformer_lens/components/rms_norm.py index 8e04887bd..94ea58293 100644 --- a/transformer_lens/components/rms_norm.py +++ b/transformer_lens/components/rms_norm.py @@ -4,12 +4,14 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. """ -from jaxtyping import Float +from typing import Dict, Optional, Union + import torch import torch.nn as nn +from jaxtyping import Float + from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -from typing import Dict, Optional, Union class RMSNorm(nn.Module): diff --git a/transformer_lens/components/rms_norm_pre.py b/transformer_lens/components/rms_norm_pre.py index 977b91287..d6b939363 100644 --- a/transformer_lens/components/rms_norm_pre.py +++ b/transformer_lens/components/rms_norm_pre.py @@ -4,12 +4,14 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. 
""" -from jaxtyping import Float +from typing import Dict, Union + import torch import torch.nn as nn +from jaxtyping import Float + from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -from typing import Dict, Union class RMSNormPre(nn.Module): diff --git a/transformer_lens/components/token_typed_embed.py b/transformer_lens/components/token_typed_embed.py index 985ad6170..ffd93a249 100644 --- a/transformer_lens/components/token_typed_embed.py +++ b/transformer_lens/components/token_typed_embed.py @@ -4,11 +4,13 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. """ -from jaxtyping import Int +from typing import Dict, Union + import torch import torch.nn as nn +from jaxtyping import Int + from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -from typing import Dict, Union class TokenTypeEmbed(nn.Module): diff --git a/transformer_lens/components/transformer_block.py b/transformer_lens/components/transformer_block.py index 98b145423..495fed86e 100644 --- a/transformer_lens/components/transformer_block.py +++ b/transformer_lens/components/transformer_block.py @@ -4,16 +4,26 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. 
""" -import einops -from jaxtyping import Float, Int import logging +from typing import Dict, Optional, Union + +import einops import torch import torch.nn as nn -from transformer_lens.components import Attention, GatedMLP, LayerNorm, LayerNormPre, MLP, RMSNorm, RMSNormPre +from jaxtyping import Float, Int + +from transformer_lens.components import ( + MLP, + Attention, + GatedMLP, + LayerNorm, + LayerNormPre, + RMSNorm, + RMSNormPre, +) from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry -from typing import Dict, Optional, Union # Transformer Block diff --git a/transformer_lens/components/unembed.py b/transformer_lens/components/unembed.py index de690c3f5..66ff868f2 100644 --- a/transformer_lens/components/unembed.py +++ b/transformer_lens/components/unembed.py @@ -4,12 +4,15 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. 
""" -from fancy_einsum import einsum -from jaxtyping import Float +from typing import Dict, Union + import torch import torch.nn as nn +from fancy_einsum import einsum +from jaxtyping import Float + from transformer_lens.HookedTransformerConfig import HookedTransformerConfig -from typing import Dict, Union + class Unembed(nn.Module): def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): From 7ba755f784326f10b55a1a3d3b02011b1c82176a Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 22 Nov 2023 00:31:37 +0100 Subject: [PATCH 05/73] reformatted components section --- transformer_lens/components/__init__.py | 2 +- transformer_lens/components/attention.py | 25 +++++++++++-------- transformer_lens/components/bert_embed.py | 1 - transformer_lens/components/embed.py | 2 +- transformer_lens/components/layer_norm.py | 2 +- transformer_lens/components/layer_norm_pre.py | 2 +- transformer_lens/components/mlp.py | 6 +++-- transformer_lens/components/rms_norm.py | 2 +- .../components/token_typed_embed.py | 2 +- 9 files changed, 24 insertions(+), 20 deletions(-) diff --git a/transformer_lens/components/__init__.py b/transformer_lens/components/__init__.py index b6b972839..07768a3e3 100644 --- a/transformer_lens/components/__init__.py +++ b/transformer_lens/components/__init__.py @@ -17,4 +17,4 @@ # Interdependent modules from .bert_block import BertBlock from .bert_embed import BertEmbed -from .transformer_block import TransformerBlock \ No newline at end of file +from .transformer_block import TransformerBlock diff --git a/transformer_lens/components/attention.py b/transformer_lens/components/attention.py index e91143cca..6490717e3 100644 --- a/transformer_lens/components/attention.py +++ b/transformer_lens/components/attention.py @@ -40,9 +40,11 @@ def __init__( layer_id (int, optional): The index of the current layer. 
Used by the Mistal models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. """ super().__init__() - - self.cached_alibi = None - self.cfg = HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg + + self.cached_alibi = None + self.cfg = ( + HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg + ) self.W_Q = nn.Parameter( torch.empty( @@ -242,7 +244,7 @@ def forward( query_ctx = attn_scores.size(-2) # The key context length is the number of positions in the past - this includes all positions in the cache key_ctx = attn_scores.size(-1) - + alibi = self.get_cached_alibi(key_ctx=key_ctx) attn_scores += alibi[ @@ -305,7 +307,6 @@ def forward( + self.b_O ) # [batch, pos, d_model] - def apply_causal_mask( self, attn_scores: Float[ @@ -547,14 +548,16 @@ def create_alibi_bias( alibi_bias = torch.einsum("ij,k->kij", slope, multipliers) return alibi_bias - - def get_cached_alibi(self, key_ctx: int) -> Float[torch.Tensor, "head_idx query key"]: + + def get_cached_alibi( + self, key_ctx: int + ) -> Float[torch.Tensor, "head_idx query key"]: """Get A Cached ALiBi bias For Calculation. - + This function will check for if an instance of our ALiBi bias is currently set. If the ALiBi bias is not set or if our key context is greater than it's cached size, a new instance will be initiated. 
- + The cached ALiBi bias is then returned Returns: @@ -565,5 +568,5 @@ def get_cached_alibi(self, key_ctx: int) -> Float[torch.Tensor, "head_idx query self.cached_alibi = Attention.create_alibi_bias( self.cfg.n_heads, key_ctx, self.cfg.device ) - - return self.cached_alibi \ No newline at end of file + + return self.cached_alibi diff --git a/transformer_lens/components/bert_embed.py b/transformer_lens/components/bert_embed.py index a9541d61f..f7e0e7d00 100644 --- a/transformer_lens/components/bert_embed.py +++ b/transformer_lens/components/bert_embed.py @@ -58,4 +58,3 @@ def forward( ) layer_norm_out = self.ln(embeddings_out) return layer_norm_out - diff --git a/transformer_lens/components/embed.py b/transformer_lens/components/embed.py index 0d988207b..b13e92ac4 100644 --- a/transformer_lens/components/embed.py +++ b/transformer_lens/components/embed.py @@ -35,4 +35,4 @@ def forward( # B acts as a tensor of indices into the second dimension (so >=0 and Date: Wed, 22 Nov 2023 20:03:05 +0100 Subject: [PATCH 06/73] reverted attention changes --- transformer_lens/components/attention.py | 68 +++++++++--------------- 1 file changed, 24 insertions(+), 44 deletions(-) diff --git a/transformer_lens/components/attention.py b/transformer_lens/components/attention.py index 6490717e3..ef7dd8c8f 100644 --- a/transformer_lens/components/attention.py +++ b/transformer_lens/components/attention.py @@ -40,12 +40,9 @@ def __init__( layer_id (int, optional): The index of the current layer. Used by the Mistal models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. 
""" super().__init__() - - self.cached_alibi = None - self.cfg = ( - HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg - ) - + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg self.W_Q = nn.Parameter( torch.empty( self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype @@ -77,29 +74,31 @@ def __init__( ) self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype)) + self.attn_type = attn_type # Create a max_ctx x max_ctx mask, with True iff that query position # can attend to that key position (query is first axis, key is second axis) causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool()) - - if attn_type == "global": + if self.attn_type == "global": # For global attention, this is a lower triangular matrix - key <= query self.register_buffer("mask", causal_mask) - elif attn_type == "local": + elif self.attn_type == "local": # For local, this is banded, query - window_size < key <= query assert isinstance(self.cfg.window_size, int) self.register_buffer( "mask", torch.triu(causal_mask, 1 - self.cfg.window_size) ) else: - raise ValueError(f"Invalid attention type: {attn_type}") + raise ValueError(f"Invalid attention type: {self.attn_type}") self.register_buffer("IGNORE", torch.tensor(-torch.inf)) self.layer_id = layer_id # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability? 
- self.attn_scale = np.sqrt(self.cfg.d_head) if self.cfg.use_attn_scale else 1.0 - + if self.cfg.use_attn_scale: + self.attn_scale = np.sqrt(self.cfg.d_head) + else: + self.attn_scale = 1.0 if self.cfg.scale_attn_by_inverse_layer_idx: self.attn_scale *= self.layer_id + 1 @@ -124,6 +123,10 @@ def __init__( ) self.register_buffer("rotary_sin", sin) self.register_buffer("rotary_cos", cos) + elif self.cfg.positional_embedding_type == "alibi": + # ALiBi bias wil be constructed on the first forward pass. + # Note: While computationally efficient, initializing an bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage. + self.alibi = None @property def OV(self) -> FactoredMatrix: @@ -179,7 +182,6 @@ def forward( qkv_einops_string = "batch pos head_index d_model" else: qkv_einops_string = "batch pos d_model" - q = self.hook_q( einsum( f"{qkv_einops_string}, head_index d_model d_head \ @@ -245,9 +247,13 @@ def forward( # The key context length is the number of positions in the past - this includes all positions in the cache key_ctx = attn_scores.size(-1) - alibi = self.get_cached_alibi(key_ctx=key_ctx) + # only recompute when necessary to increase efficiency. 
+ if self.alibi is None or key_ctx > self.alibi.size(-1): + self.alibi = Attention.create_alibi_bias( + self.cfg.n_heads, key_ctx, self.cfg.device + ) - attn_scores += alibi[ + attn_scores += self.alibi[ :, :query_ctx, :key_ctx ] # [batch, head_index, query_pos, key_pos] @@ -256,7 +262,6 @@ def forward( attn_scores = self.apply_causal_mask( attn_scores, kv_cache_pos_offset, attention_mask ) # [batch, head_index, query_pos, key_pos] - if additive_attention_mask is not None: attn_scores += additive_attention_mask @@ -274,9 +279,8 @@ def forward( pattern, ) ) # [batch, pos, head_index, d_head] - if not self.cfg.use_attn_result: - return ( + out = ( ( einsum( "batch pos head_index d_head, \ @@ -300,12 +304,13 @@ def forward( self.W_O, ) ) # [batch, pos, head_index, d_model] - return ( + out = ( einops.reduce( result, "batch position index model->batch position model", "sum" ) + self.b_O ) # [batch, pos, d_model] + return out def apply_causal_mask( self, @@ -329,7 +334,6 @@ def apply_causal_mask( final_mask = self.mask[ None, None, -query_ctx_length:, -key_ctx_length: ] # [1, 1, pos, pos] - if attention_mask is not None: # Apply a causal mask to the attention scores considering the padding einsum_str = "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos" @@ -362,10 +366,8 @@ def calculate_sin_cos_rotary( freq = einops.repeat(freq, "d -> (2 d)") else: freq = einops.repeat(freq, "d -> (d 2)") - # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency angles = pos[:, None] / freq[None, :] - return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) def rotate_every_two( @@ -548,25 +550,3 @@ def create_alibi_bias( alibi_bias = torch.einsum("ij,k->kij", slope, multipliers) return alibi_bias - - def get_cached_alibi( - self, key_ctx: int - ) -> Float[torch.Tensor, "head_idx query key"]: - """Get A Cached ALiBi bias For Calculation. 
- - This function will check for if an instance of our ALiBi bias is currently set. - If the ALiBi bias is not set or if our key context is greater than it's cached size, a new - instance will be initiated. - - The cached ALiBi bias is then returned - - Returns: - The ALiBi bias that should be added to the attention scores before the softmax. - """ - # only recompute when necessary to increase efficiency. - if self.cached_alibi is None or key_ctx > self.cached_alibi.size(-1): - self.cached_alibi = Attention.create_alibi_bias( - self.cfg.n_heads, key_ctx, self.cfg.device - ) - - return self.cached_alibi From cecf93ed4204fc316875584a7024a4302035c4c2 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 22 Nov 2023 20:15:17 +0100 Subject: [PATCH 07/73] reverted layer norm changes --- transformer_lens/components/layer_norm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_lens/components/layer_norm.py b/transformer_lens/components/layer_norm.py index 42bfc10b5..35118d37c 100644 --- a/transformer_lens/components/layer_norm.py +++ b/transformer_lens/components/layer_norm.py @@ -28,7 +28,10 @@ def __init__( cfg = HookedTransformerConfig.from_dict(cfg) self.cfg = cfg self.eps = self.cfg.eps - self.length = self.cfg.d_model if length is None else length + if length is None: + self.length = self.cfg.d_model + else: + self.length = length self.w = nn.Parameter(torch.ones(self.length, dtype=cfg.dtype)) self.b = nn.Parameter(torch.zeros(self.length, dtype=cfg.dtype)) From 9d8e91a244247be411050c3dc8c1c0c02f092abd Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 22 Nov 2023 23:05:01 +0100 Subject: [PATCH 08/73] reverted mlp changes --- transformer_lens/components/mlp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_lens/components/mlp.py b/transformer_lens/components/mlp.py index 80bdf501e..6cbc811ec 100644 --- a/transformer_lens/components/mlp.py +++ b/transformer_lens/components/mlp.py @@ -68,9 +68,7 @@ 
def forward( + self.b_in ) # [batch, pos, d_mlp] if not self.cfg.act_fn.endswith("_ln"): - post_act = self.hook_post( - self.act_fn(pre_act) - ) # [batch, pos, d_mlp] TODO segmentation fault + post_act = self.hook_post(self.act_fn(pre_act)) # [batch, pos, d_mlp] else: mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] post_act = self.hook_post(self.ln(mid_act)) From b9a9ba5fac264c4420b031eb4498413c92b10216 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 22 Nov 2023 23:14:46 +0100 Subject: [PATCH 09/73] Revert "reverted attention changes" This reverts commit 91a5712194563f35766d8946facd4f53ba2a2f84. --- transformer_lens/components/attention.py | 68 +++++++++++++++--------- 1 file changed, 44 insertions(+), 24 deletions(-) diff --git a/transformer_lens/components/attention.py b/transformer_lens/components/attention.py index ef7dd8c8f..6490717e3 100644 --- a/transformer_lens/components/attention.py +++ b/transformer_lens/components/attention.py @@ -40,9 +40,12 @@ def __init__( layer_id (int, optional): The index of the current layer. Used by the Mistal models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. 
""" super().__init__() - if isinstance(cfg, Dict): - cfg = HookedTransformerConfig.from_dict(cfg) - self.cfg = cfg + + self.cached_alibi = None + self.cfg = ( + HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg + ) + self.W_Q = nn.Parameter( torch.empty( self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype @@ -74,31 +77,29 @@ def __init__( ) self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype)) - self.attn_type = attn_type # Create a max_ctx x max_ctx mask, with True iff that query position # can attend to that key position (query is first axis, key is second axis) causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool()) - if self.attn_type == "global": + + if attn_type == "global": # For global attention, this is a lower triangular matrix - key <= query self.register_buffer("mask", causal_mask) - elif self.attn_type == "local": + elif attn_type == "local": # For local, this is banded, query - window_size < key <= query assert isinstance(self.cfg.window_size, int) self.register_buffer( "mask", torch.triu(causal_mask, 1 - self.cfg.window_size) ) else: - raise ValueError(f"Invalid attention type: {self.attn_type}") + raise ValueError(f"Invalid attention type: {attn_type}") self.register_buffer("IGNORE", torch.tensor(-torch.inf)) self.layer_id = layer_id # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability? 
- if self.cfg.use_attn_scale: - self.attn_scale = np.sqrt(self.cfg.d_head) - else: - self.attn_scale = 1.0 + self.attn_scale = np.sqrt(self.cfg.d_head) if self.cfg.use_attn_scale else 1.0 + if self.cfg.scale_attn_by_inverse_layer_idx: self.attn_scale *= self.layer_id + 1 @@ -123,10 +124,6 @@ def __init__( ) self.register_buffer("rotary_sin", sin) self.register_buffer("rotary_cos", cos) - elif self.cfg.positional_embedding_type == "alibi": - # ALiBi bias wil be constructed on the first forward pass. - # Note: While computationally efficient, initializing an bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage. - self.alibi = None @property def OV(self) -> FactoredMatrix: @@ -182,6 +179,7 @@ def forward( qkv_einops_string = "batch pos head_index d_model" else: qkv_einops_string = "batch pos d_model" + q = self.hook_q( einsum( f"{qkv_einops_string}, head_index d_model d_head \ @@ -247,13 +245,9 @@ def forward( # The key context length is the number of positions in the past - this includes all positions in the cache key_ctx = attn_scores.size(-1) - # only recompute when necessary to increase efficiency. 
- if self.alibi is None or key_ctx > self.alibi.size(-1): - self.alibi = Attention.create_alibi_bias( - self.cfg.n_heads, key_ctx, self.cfg.device - ) + alibi = self.get_cached_alibi(key_ctx=key_ctx) - attn_scores += self.alibi[ + attn_scores += alibi[ :, :query_ctx, :key_ctx ] # [batch, head_index, query_pos, key_pos] @@ -262,6 +256,7 @@ def forward( attn_scores = self.apply_causal_mask( attn_scores, kv_cache_pos_offset, attention_mask ) # [batch, head_index, query_pos, key_pos] + if additive_attention_mask is not None: attn_scores += additive_attention_mask @@ -279,8 +274,9 @@ def forward( pattern, ) ) # [batch, pos, head_index, d_head] + if not self.cfg.use_attn_result: - out = ( + return ( ( einsum( "batch pos head_index d_head, \ @@ -304,13 +300,12 @@ def forward( self.W_O, ) ) # [batch, pos, head_index, d_model] - out = ( + return ( einops.reduce( result, "batch position index model->batch position model", "sum" ) + self.b_O ) # [batch, pos, d_model] - return out def apply_causal_mask( self, @@ -334,6 +329,7 @@ def apply_causal_mask( final_mask = self.mask[ None, None, -query_ctx_length:, -key_ctx_length: ] # [1, 1, pos, pos] + if attention_mask is not None: # Apply a causal mask to the attention scores considering the padding einsum_str = "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos" @@ -366,8 +362,10 @@ def calculate_sin_cos_rotary( freq = einops.repeat(freq, "d -> (2 d)") else: freq = einops.repeat(freq, "d -> (d 2)") + # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency angles = pos[:, None] / freq[None, :] + return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) def rotate_every_two( @@ -550,3 +548,25 @@ def create_alibi_bias( alibi_bias = torch.einsum("ij,k->kij", slope, multipliers) return alibi_bias + + def get_cached_alibi( + self, key_ctx: int + ) -> Float[torch.Tensor, "head_idx query key"]: + """Get A Cached ALiBi bias For Calculation. 
+ + This function will check for if an instance of our ALiBi bias is currently set. + If the ALiBi bias is not set or if our key context is greater than it's cached size, a new + instance will be initiated. + + The cached ALiBi bias is then returned + + Returns: + The ALiBi bias that should be added to the attention scores before the softmax. + """ + # only recompute when necessary to increase efficiency. + if self.cached_alibi is None or key_ctx > self.cached_alibi.size(-1): + self.cached_alibi = Attention.create_alibi_bias( + self.cfg.n_heads, key_ctx, self.cfg.device + ) + + return self.cached_alibi From da0acf3dd493acfa24b1bffd3567b466d3444545 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 22 Nov 2023 23:15:03 +0100 Subject: [PATCH 10/73] Revert "reverted layer norm changes" This reverts commit cecf93ed4204fc316875584a7024a4302035c4c2. --- transformer_lens/components/layer_norm.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/transformer_lens/components/layer_norm.py b/transformer_lens/components/layer_norm.py index 35118d37c..42bfc10b5 100644 --- a/transformer_lens/components/layer_norm.py +++ b/transformer_lens/components/layer_norm.py @@ -28,10 +28,7 @@ def __init__( cfg = HookedTransformerConfig.from_dict(cfg) self.cfg = cfg self.eps = self.cfg.eps - if length is None: - self.length = self.cfg.d_model - else: - self.length = length + self.length = self.cfg.d_model if length is None else length self.w = nn.Parameter(torch.ones(self.length, dtype=cfg.dtype)) self.b = nn.Parameter(torch.zeros(self.length, dtype=cfg.dtype)) From 10d609e2d3c671308791efc4edacbef098f70c06 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 22 Nov 2023 23:15:13 +0100 Subject: [PATCH 11/73] Revert "reverted mlp changes" This reverts commit 9d8e91a244247be411050c3dc8c1c0c02f092abd. 
--- transformer_lens/components/mlp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_lens/components/mlp.py b/transformer_lens/components/mlp.py index 6cbc811ec..80bdf501e 100644 --- a/transformer_lens/components/mlp.py +++ b/transformer_lens/components/mlp.py @@ -68,7 +68,9 @@ def forward( + self.b_in ) # [batch, pos, d_mlp] if not self.cfg.act_fn.endswith("_ln"): - post_act = self.hook_post(self.act_fn(pre_act)) # [batch, pos, d_mlp] + post_act = self.hook_post( + self.act_fn(pre_act) + ) # [batch, pos, d_mlp] TODO segmentation fault else: mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] post_act = self.hook_post(self.ln(mid_act)) From ef4518b3e848d799409c8b7045ec7206dc68bf23 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 22 Nov 2023 23:20:16 +0100 Subject: [PATCH 12/73] removed some model tests --- tests/acceptance/__init__.py | 0 tests/acceptance/test_hooked_transformer.py | 2 +- tests/manual_checks/__init__.py | 0 tests/unit/__init__.py | 0 tests/unit/factored_matrix/__init__.py | 0 5 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 tests/acceptance/__init__.py create mode 100644 tests/manual_checks/__init__.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/factored_matrix/__init__.py diff --git a/tests/acceptance/__init__.py b/tests/acceptance/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index 545e39b3b..e7b08b7c1 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -206,7 +206,7 @@ def check_performance(tl_model, hf_model, margin): def check_dtype(dtype, margin, no_processing=False): """Check the loading and inferences for different dtypes.""" - for model_path in ["gpt2", "roneneldan/TinyStories-33M", "EleutherAI/pythia-70m"]: + for model_path in ["gpt2"]: if no_processing: # For low 
precision, the processing is not advised. model = HookedTransformer.from_pretrained_no_processing( diff --git a/tests/manual_checks/__init__.py b/tests/manual_checks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/factored_matrix/__init__.py b/tests/unit/factored_matrix/__init__.py new file mode 100644 index 000000000..e69de29bb From 6cfb08f8e07246b25b315f1992f53d319e008d95 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 22 Nov 2023 23:20:26 +0100 Subject: [PATCH 13/73] Revert "removed some model tests" This reverts commit ef4518b3e848d799409c8b7045ec7206dc68bf23. --- tests/acceptance/__init__.py | 0 tests/acceptance/test_hooked_transformer.py | 2 +- tests/manual_checks/__init__.py | 0 tests/unit/__init__.py | 0 tests/unit/factored_matrix/__init__.py | 0 5 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 tests/acceptance/__init__.py delete mode 100644 tests/manual_checks/__init__.py delete mode 100644 tests/unit/__init__.py delete mode 100644 tests/unit/factored_matrix/__init__.py diff --git a/tests/acceptance/__init__.py b/tests/acceptance/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index e7b08b7c1..545e39b3b 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -206,7 +206,7 @@ def check_performance(tl_model, hf_model, margin): def check_dtype(dtype, margin, no_processing=False): """Check the loading and inferences for different dtypes.""" - for model_path in ["gpt2"]: + for model_path in ["gpt2", "roneneldan/TinyStories-33M", "EleutherAI/pythia-70m"]: if no_processing: # For low precision, the processing is not advised. 
model = HookedTransformer.from_pretrained_no_processing( diff --git a/tests/manual_checks/__init__.py b/tests/manual_checks/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/factored_matrix/__init__.py b/tests/unit/factored_matrix/__init__.py deleted file mode 100644 index e69de29bb..000000000 From 8832b7d1dbf2ee85035ba3928d6899996a4c34b8 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 22 Nov 2023 23:21:00 +0100 Subject: [PATCH 14/73] removed model tests --- tests/acceptance/test_hooked_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index 545e39b3b..e7b08b7c1 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -206,7 +206,7 @@ def check_performance(tl_model, hf_model, margin): def check_dtype(dtype, margin, no_processing=False): """Check the loading and inferences for different dtypes.""" - for model_path in ["gpt2", "roneneldan/TinyStories-33M", "EleutherAI/pythia-70m"]: + for model_path in ["gpt2"]: if no_processing: # For low precision, the processing is not advised. 
model = HookedTransformer.from_pretrained_no_processing( From 502fe652f676bc4e6246c57d6bcf8abff8d4a50e Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 22 Nov 2023 23:39:08 +0100 Subject: [PATCH 15/73] added model back --- tests/acceptance/test_hooked_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index e7b08b7c1..fa0ffe634 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -206,7 +206,7 @@ def check_performance(tl_model, hf_model, margin): def check_dtype(dtype, margin, no_processing=False): """Check the loading and inferences for different dtypes.""" - for model_path in ["gpt2"]: + for model_path in ["gpt2", "roneneldan/TinyStories-33M"]: if no_processing: # For low precision, the processing is not advised. model = HookedTransformer.from_pretrained_no_processing( From beb014e7051ce08f545b427aa58c484c581525e4 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Thu, 23 Nov 2023 00:18:39 +0100 Subject: [PATCH 16/73] lowered accuracy --- tests/acceptance/test_hooked_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index fa0ffe634..e60639e57 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -206,7 +206,7 @@ def check_performance(tl_model, hf_model, margin): def check_dtype(dtype, margin, no_processing=False): """Check the loading and inferences for different dtypes.""" - for model_path in ["gpt2", "roneneldan/TinyStories-33M"]: + for model_path in ["gpt2", "roneneldan/TinyStories-33M", "EleutherAI/pythia-70m"]: if no_processing: # For low precision, the processing is not advised. 
model = HookedTransformer.from_pretrained_no_processing( @@ -235,7 +235,7 @@ def check_dtype(dtype, margin, no_processing=False): @pytest.mark.parametrize("dtype", [torch.float64, torch.float32]) def test_dtypes(dtype): - check_dtype(dtype, margin=5e-5) + check_dtype(dtype, margin=5e-4) @pytest.mark.skipif( From a7ca7ea36224e83579f6e44c6cc965e0c9fa6563 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Thu, 23 Nov 2023 00:34:16 +0100 Subject: [PATCH 17/73] Revert "lowered accuracy" This reverts commit beb014e7051ce08f545b427aa58c484c581525e4. --- tests/acceptance/test_hooked_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index e60639e57..fa0ffe634 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -206,7 +206,7 @@ def check_performance(tl_model, hf_model, margin): def check_dtype(dtype, margin, no_processing=False): """Check the loading and inferences for different dtypes.""" - for model_path in ["gpt2", "roneneldan/TinyStories-33M", "EleutherAI/pythia-70m"]: + for model_path in ["gpt2", "roneneldan/TinyStories-33M"]: if no_processing: # For low precision, the processing is not advised. 
model = HookedTransformer.from_pretrained_no_processing( @@ -235,7 +235,7 @@ def check_dtype(dtype, margin, no_processing=False): @pytest.mark.parametrize("dtype", [torch.float64, torch.float32]) def test_dtypes(dtype): - check_dtype(dtype, margin=5e-4) + check_dtype(dtype, margin=5e-5) @pytest.mark.skipif( From fbf03a6ae175146b3ac050fdf6619f6af9262692 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Thu, 23 Nov 2023 00:59:59 +0100 Subject: [PATCH 18/73] added model back to test loop --- tests/acceptance/test_hooked_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index fa0ffe634..e60639e57 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -206,7 +206,7 @@ def check_performance(tl_model, hf_model, margin): def check_dtype(dtype, margin, no_processing=False): """Check the loading and inferences for different dtypes.""" - for model_path in ["gpt2", "roneneldan/TinyStories-33M"]: + for model_path in ["gpt2", "roneneldan/TinyStories-33M", "EleutherAI/pythia-70m"]: if no_processing: # For low precision, the processing is not advised. 
model = HookedTransformer.from_pretrained_no_processing( @@ -235,7 +235,7 @@ def check_dtype(dtype, margin, no_processing=False): @pytest.mark.parametrize("dtype", [torch.float64, torch.float32]) def test_dtypes(dtype): - check_dtype(dtype, margin=5e-5) + check_dtype(dtype, margin=5e-4) @pytest.mark.skipif( From 1c7a8bd990af0e29ee98c167a7ca75146d237c19 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Thu, 23 Nov 2023 01:20:23 +0100 Subject: [PATCH 19/73] reverted accuracy change --- tests/acceptance/test_hooked_transformer.py | 2 +- transformer_lens/components/__init__.py | 6 ++++++ transformer_lens/components/attention.py | 7 ++----- transformer_lens/components/bert_block.py | 6 ++---- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index e60639e57..545e39b3b 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -235,7 +235,7 @@ def check_dtype(dtype, margin, no_processing=False): @pytest.mark.parametrize("dtype", [torch.float64, torch.float32]) def test_dtypes(dtype): - check_dtype(dtype, margin=5e-4) + check_dtype(dtype, margin=5e-5) @pytest.mark.skipif( diff --git a/transformer_lens/components/__init__.py b/transformer_lens/components/__init__.py index 07768a3e3..e88e7d1e9 100644 --- a/transformer_lens/components/__init__.py +++ b/transformer_lens/components/__init__.py @@ -1,3 +1,9 @@ +"""Hooked Transformer Components. + +This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) +needed to create many different types of generative language models. They are used by +:class:`transformer_lens.HookedTransformer`. 
+""" # Independent classes from .attention import Attention from .layer_norm import LayerNorm diff --git a/transformer_lens/components/attention.py b/transformer_lens/components/attention.py index 6490717e3..115f70f17 100644 --- a/transformer_lens/components/attention.py +++ b/transformer_lens/components/attention.py @@ -1,10 +1,7 @@ -"""Hooked Transformer Components. +"""Hooked Transformer Attention Component. -This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`Attention`. """ - from typing import Dict, Optional, Tuple, Union import einops diff --git a/transformer_lens/components/bert_block.py b/transformer_lens/components/bert_block.py index e531bcd91..0eee5dc33 100644 --- a/transformer_lens/components/bert_block.py +++ b/transformer_lens/components/bert_block.py @@ -1,8 +1,6 @@ -"""Hooked Transformer Components. +"""Hooked Transformer Bert Block Component. -This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`BertBlock`. 
""" from typing import Optional From 42c195a0ee5cf785c79dddf87b3bf6e2b435ef17 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Thu, 23 Nov 2023 01:32:03 +0100 Subject: [PATCH 20/73] added proper headers --- transformer_lens/components/bert_embed.py | 6 ++---- transformer_lens/components/bert_mlm_head.py | 6 ++---- transformer_lens/components/embed.py | 4 +--- transformer_lens/components/gated_mlp.py | 6 ++---- transformer_lens/components/layer_norm.py | 6 ++---- transformer_lens/components/layer_norm_pre.py | 6 ++---- transformer_lens/components/mlp.py | 6 ++---- transformer_lens/components/pos_embed.py | 6 ++---- transformer_lens/components/rms_norm.py | 6 ++---- transformer_lens/components/rms_norm_pre.py | 6 ++---- transformer_lens/components/token_typed_embed.py | 6 ++---- transformer_lens/components/transformer_block.py | 6 ++---- transformer_lens/components/unembed.py | 4 +--- 13 files changed, 24 insertions(+), 50 deletions(-) diff --git a/transformer_lens/components/bert_embed.py b/transformer_lens/components/bert_embed.py index f7e0e7d00..3c8046270 100644 --- a/transformer_lens/components/bert_embed.py +++ b/transformer_lens/components/bert_embed.py @@ -1,8 +1,6 @@ -"""Hooked Transformer Components. +"""Hooked Transformer Bert Embed Component. -This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`BertEmbed`. """ from typing import Dict, Optional, Union diff --git a/transformer_lens/components/bert_mlm_head.py b/transformer_lens/components/bert_mlm_head.py index 745fc5d13..878abec45 100644 --- a/transformer_lens/components/bert_mlm_head.py +++ b/transformer_lens/components/bert_mlm_head.py @@ -1,8 +1,6 @@ -"""Hooked Transformer Components. +"""Hooked Transformer Bert MLM Head Component. -This module contains all the components (e.g. 
:class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`BertMLMHead`. """ from typing import Dict, Union diff --git a/transformer_lens/components/embed.py b/transformer_lens/components/embed.py index b13e92ac4..da4c4fff1 100644 --- a/transformer_lens/components/embed.py +++ b/transformer_lens/components/embed.py @@ -1,8 +1,6 @@ """Hooked Transformer Embed Component. -This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`BertMLMHead`. """ from typing import Dict, Union diff --git a/transformer_lens/components/gated_mlp.py b/transformer_lens/components/gated_mlp.py index 7b74272b4..490b862c1 100644 --- a/transformer_lens/components/gated_mlp.py +++ b/transformer_lens/components/gated_mlp.py @@ -1,8 +1,6 @@ -"""Hooked Transformer Components. +"""Hooked Transformer Gated MLP Component. -This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`GatedMLP`. """ from typing import Dict, Union diff --git a/transformer_lens/components/layer_norm.py b/transformer_lens/components/layer_norm.py index 42bfc10b5..3d3d9bb8d 100644 --- a/transformer_lens/components/layer_norm.py +++ b/transformer_lens/components/layer_norm.py @@ -1,8 +1,6 @@ -"""Hooked Transformer Components. +"""Hooked Transformer Layer Norm Component. -This module contains all the components (e.g. 
:class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`LayerNorm`. """ from typing import Dict, Optional, Union diff --git a/transformer_lens/components/layer_norm_pre.py b/transformer_lens/components/layer_norm_pre.py index dbcd306f7..d318c0a0f 100644 --- a/transformer_lens/components/layer_norm_pre.py +++ b/transformer_lens/components/layer_norm_pre.py @@ -1,8 +1,6 @@ -"""Hooked Transformer Components. +"""Hooked Transformer Layer Norm Pre Component. -This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`LayerNormPre`. """ from typing import Dict, Union diff --git a/transformer_lens/components/mlp.py b/transformer_lens/components/mlp.py index 80bdf501e..8625d4826 100644 --- a/transformer_lens/components/mlp.py +++ b/transformer_lens/components/mlp.py @@ -1,8 +1,6 @@ -"""Hooked Transformer Components. +"""Hooked Transformer MLP Component. -This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`MLP`. """ from typing import Dict, Union diff --git a/transformer_lens/components/pos_embed.py b/transformer_lens/components/pos_embed.py index 6d4909a94..fa711c092 100644 --- a/transformer_lens/components/pos_embed.py +++ b/transformer_lens/components/pos_embed.py @@ -1,8 +1,6 @@ -"""Hooked Transformer Embed Component. +"""Hooked Transformer POS Embed Component. -This module contains all the components (e.g. 
:class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`PosEmbed`. """ from typing import Dict, Optional, Union diff --git a/transformer_lens/components/rms_norm.py b/transformer_lens/components/rms_norm.py index ec823b0ba..f9716b07d 100644 --- a/transformer_lens/components/rms_norm.py +++ b/transformer_lens/components/rms_norm.py @@ -1,8 +1,6 @@ -"""Hooked Transformer Components. +"""Hooked Transformer RMS Norm Component. -This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`RMSNorm`. """ from typing import Dict, Optional, Union diff --git a/transformer_lens/components/rms_norm_pre.py b/transformer_lens/components/rms_norm_pre.py index d6b939363..e74d8cbe6 100644 --- a/transformer_lens/components/rms_norm_pre.py +++ b/transformer_lens/components/rms_norm_pre.py @@ -1,8 +1,6 @@ -"""Hooked Transformer Components. +"""Hooked Transformer RMS Norm Pre Component. -This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`RMSNormPre`. """ from typing import Dict, Union diff --git a/transformer_lens/components/token_typed_embed.py b/transformer_lens/components/token_typed_embed.py index ad0005e21..5a9f3f10d 100644 --- a/transformer_lens/components/token_typed_embed.py +++ b/transformer_lens/components/token_typed_embed.py @@ -1,8 +1,6 @@ -"""Hooked Transformer Components. +"""Hooked Transformer Token Typed Embed Component. 
-This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`TokenTypeEmbed`. """ from typing import Dict, Union diff --git a/transformer_lens/components/transformer_block.py b/transformer_lens/components/transformer_block.py index 495fed86e..af9b807bb 100644 --- a/transformer_lens/components/transformer_block.py +++ b/transformer_lens/components/transformer_block.py @@ -1,8 +1,6 @@ -"""Hooked Transformer Components. +"""Hooked Transformer Transformer Block Component. -This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`TransformerBlock`. """ import logging from typing import Dict, Optional, Union diff --git a/transformer_lens/components/unembed.py b/transformer_lens/components/unembed.py index 66ff868f2..d1de2ea2f 100644 --- a/transformer_lens/components/unembed.py +++ b/transformer_lens/components/unembed.py @@ -1,8 +1,6 @@ """Hooked Transformer Unembed Component. -This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`) -needed to create many different types of generative language models. They are used by -:class:`transformer_lens.HookedTransformer`. +This module contains all the component :class:`Unembed`. """ from typing import Dict, Union From ce82675a8e89b6d5e6229a89620c843c794f3b04 Mon Sep 17 00:00:00 2001 From: Alan <41682961+alan-cooney@users.noreply.github.com> Date: Sun, 10 Dec 2023 10:37:59 -0300 Subject: [PATCH 21/73] Clean up project config (#463) Remove the pytorch versioning fix as this has been solved with the latest pytorch version. 
Also format with even better toml so that the pyproject is easier to read. --- .vscode/cspell.json | 5 +- .vscode/extensions.json | 5 +- .vscode/settings.json | 40 +- poetry.lock | 2055 +++++++++++++++++++------------------ pyproject.toml | 196 ++-- transformer_lens/utils.py | 2 + 6 files changed, 1166 insertions(+), 1137 deletions(-) diff --git a/.vscode/cspell.json b/.vscode/cspell.json index 738113240..ba26db277 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -2,6 +2,7 @@ "language": "en,en-GB", "words": [ "adrià", + "accum", "aengus", "alonso", "arange", @@ -9,7 +10,7 @@ "autodiff", "autoregressive", "barez", - "Beartype", + "beartype", "belrose", "bertsimas", "biderman", @@ -18,7 +19,7 @@ "checkpointed", "chughtai", "circuitsvis", - "Codespaces", + "codespaces", "colab", "collectstart", "colour", diff --git a/.vscode/extensions.json b/.vscode/extensions.json index 0eb6d4710..2d47fb339 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -17,8 +17,9 @@ "ms-toolsai.jupyter", "richie5um2.vscode-sort-json", "stkb.rewrap", + "streetsidesoftware.code-spell-checker-british-english", "streetsidesoftware.code-spell-checker", - "yzhang.markdown-all-in-one", - "streetsidesoftware.code-spell-checker-british-english" + "tamasfe.even-better-toml", + "yzhang.markdown-all-in-one" ] } \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 7092a6781..1c479871d 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,11 +1,36 @@ { - "editor.formatOnSave": true, - "editor.codeActionsOnSave": { - "source.organizeImports": true - }, "[python]": { "editor.defaultFormatter": "ms-python.black-formatter" }, + "[toml]": { + "editor.defaultFormatter": "tamasfe.even-better-toml" + }, + "editor.codeActionsOnSave": { + "source.organizeImports": true + }, + "editor.formatOnSave": true, + "evenBetterToml.formatter.allowedBlankLines": 1, + "evenBetterToml.formatter.arrayAutoCollapse": true, + 
"evenBetterToml.formatter.arrayAutoExpand": true, + "evenBetterToml.formatter.arrayTrailingComma": true, + "evenBetterToml.formatter.columnWidth": 100, + "evenBetterToml.formatter.compactArrays": true, + "evenBetterToml.formatter.compactEntries": true, + "evenBetterToml.formatter.compactInlineTables": true, + "evenBetterToml.formatter.indentEntries": true, + "evenBetterToml.formatter.indentString": " ", + "evenBetterToml.formatter.indentTables": true, + "evenBetterToml.formatter.inlineTableExpand": true, + "evenBetterToml.formatter.reorderArrays": true, + "evenBetterToml.formatter.reorderKeys": true, + "evenBetterToml.formatter.trailingNewline": true, + "evenBetterToml.schema.enabled": true, + "evenBetterToml.schema.links": true, + "evenBetterToml.syntax.semanticTokens": false, + "mypy-type-checker.importStrategy": "fromEnvironment", + "notebook.formatOnCellExecution": true, + "notebook.formatOnSave.enabled": true, + "pylint.importStrategy": "fromEnvironment", "python.testing.pytestArgs": [ "transformer_lens", ], @@ -13,11 +38,4 @@ "rewrap.autoWrap.enabled": true, "rewrap.reformat": true, "rewrap.wrappingColumn": 100, - "mypy-type-checker.importStrategy": "fromEnvironment", - "pylint.importStrategy": "fromEnvironment", - "notebook.formatOnCellExecution": true, - "notebook.formatOnSave.enabled": true, - "cSpell.words": [ - "accum" - ], } \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index bac8ff0c5..c95a1e5f2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,13 +1,14 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. 
[[package]] name = "accelerate" -version = "0.24.0" +version = "0.25.0" description = "Accelerate" optional = false python-versions = ">=3.8.0" files = [ - {file = "accelerate-0.24.0-py3-none-any.whl", hash = "sha256:04bb1483c90eacb3beb2687cb54950d8caf9a0b93432f6b2d42efebbb6c0491e"}, + {file = "accelerate-0.25.0-py3-none-any.whl", hash = "sha256:c7bb817eb974bba0ff3ea1ba0f24d55afb86d50e3d4fe98d6922dc69cf2ccff1"}, + {file = "accelerate-0.25.0.tar.gz", hash = "sha256:ecf55b0ab278a1dac8539dde0d276977aff04683f07ede73eaf02478538576a1"}, ] [package.dependencies] @@ -16,6 +17,7 @@ numpy = ">=1.17" packaging = ">=20.0" psutil = "*" pyyaml = "*" +safetensors = ">=0.3.1" torch = ">=1.10.0" [package.extras] @@ -25,116 +27,104 @@ rich = ["rich"] sagemaker = ["sagemaker"] test-dev = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "timm", "tqdm", "transformers"] test-prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"] -test-trackers = ["comet-ml", "tensorboard", "wandb"] +test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"] testing = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "tqdm", "transformers"] [[package]] name = "aiohttp" -version = "3.8.6" +version = "3.9.1" description = "Async http client/server framework (asyncio)" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "aiohttp-3.8.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:41d55fc043954cddbbd82503d9cc3f4814a40bcef30b3569bc7b5e34130718c1"}, - {file = "aiohttp-3.8.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1d84166673694841d8953f0a8d0c90e1087739d24632fe86b1a08819168b4566"}, - {file = "aiohttp-3.8.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:253bf92b744b3170eb4c4ca2fa58f9c4b87aeb1df42f71d4e78815e6e8b73c9e"}, - {file = 
"aiohttp-3.8.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3fd194939b1f764d6bb05490987bfe104287bbf51b8d862261ccf66f48fb4096"}, - {file = "aiohttp-3.8.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6c5f938d199a6fdbdc10bbb9447496561c3a9a565b43be564648d81e1102ac22"}, - {file = "aiohttp-3.8.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2817b2f66ca82ee699acd90e05c95e79bbf1dc986abb62b61ec8aaf851e81c93"}, - {file = "aiohttp-3.8.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fa375b3d34e71ccccf172cab401cd94a72de7a8cc01847a7b3386204093bb47"}, - {file = "aiohttp-3.8.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9de50a199b7710fa2904be5a4a9b51af587ab24c8e540a7243ab737b45844543"}, - {file = "aiohttp-3.8.6-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e1d8cb0b56b3587c5c01de3bf2f600f186da7e7b5f7353d1bf26a8ddca57f965"}, - {file = "aiohttp-3.8.6-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8e31e9db1bee8b4f407b77fd2507337a0a80665ad7b6c749d08df595d88f1cf5"}, - {file = "aiohttp-3.8.6-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7bc88fc494b1f0311d67f29fee6fd636606f4697e8cc793a2d912ac5b19aa38d"}, - {file = "aiohttp-3.8.6-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ec00c3305788e04bf6d29d42e504560e159ccaf0be30c09203b468a6c1ccd3b2"}, - {file = "aiohttp-3.8.6-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ad1407db8f2f49329729564f71685557157bfa42b48f4b93e53721a16eb813ed"}, - {file = "aiohttp-3.8.6-cp310-cp310-win32.whl", hash = "sha256:ccc360e87341ad47c777f5723f68adbb52b37ab450c8bc3ca9ca1f3e849e5fe2"}, - {file = "aiohttp-3.8.6-cp310-cp310-win_amd64.whl", hash = "sha256:93c15c8e48e5e7b89d5cb4613479d144fda8344e2d886cf694fd36db4cc86865"}, - {file = "aiohttp-3.8.6-cp311-cp311-macosx_10_9_universal2.whl", hash = 
"sha256:6e2f9cc8e5328f829f6e1fb74a0a3a939b14e67e80832975e01929e320386b34"}, - {file = "aiohttp-3.8.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e6a00ffcc173e765e200ceefb06399ba09c06db97f401f920513a10c803604ca"}, - {file = "aiohttp-3.8.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:41bdc2ba359032e36c0e9de5a3bd00d6fb7ea558a6ce6b70acedf0da86458321"}, - {file = "aiohttp-3.8.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14cd52ccf40006c7a6cd34a0f8663734e5363fd981807173faf3a017e202fec9"}, - {file = "aiohttp-3.8.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2d5b785c792802e7b275c420d84f3397668e9d49ab1cb52bd916b3b3ffcf09ad"}, - {file = "aiohttp-3.8.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1bed815f3dc3d915c5c1e556c397c8667826fbc1b935d95b0ad680787896a358"}, - {file = "aiohttp-3.8.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96603a562b546632441926cd1293cfcb5b69f0b4159e6077f7c7dbdfb686af4d"}, - {file = "aiohttp-3.8.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d76e8b13161a202d14c9584590c4df4d068c9567c99506497bdd67eaedf36403"}, - {file = "aiohttp-3.8.6-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e3f1e3f1a1751bb62b4a1b7f4e435afcdade6c17a4fd9b9d43607cebd242924a"}, - {file = "aiohttp-3.8.6-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:76b36b3124f0223903609944a3c8bf28a599b2cc0ce0be60b45211c8e9be97f8"}, - {file = "aiohttp-3.8.6-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:a2ece4af1f3c967a4390c284797ab595a9f1bc1130ef8b01828915a05a6ae684"}, - {file = "aiohttp-3.8.6-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:16d330b3b9db87c3883e565340d292638a878236418b23cc8b9b11a054aaa887"}, - {file = "aiohttp-3.8.6-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:42c89579f82e49db436b69c938ab3e1559e5a4409eb8639eb4143989bc390f2f"}, - {file = 
"aiohttp-3.8.6-cp311-cp311-win32.whl", hash = "sha256:efd2fcf7e7b9d7ab16e6b7d54205beded0a9c8566cb30f09c1abe42b4e22bdcb"}, - {file = "aiohttp-3.8.6-cp311-cp311-win_amd64.whl", hash = "sha256:3b2ab182fc28e7a81f6c70bfbd829045d9480063f5ab06f6e601a3eddbbd49a0"}, - {file = "aiohttp-3.8.6-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:fdee8405931b0615220e5ddf8cd7edd8592c606a8e4ca2a00704883c396e4479"}, - {file = "aiohttp-3.8.6-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d25036d161c4fe2225d1abff2bd52c34ed0b1099f02c208cd34d8c05729882f0"}, - {file = "aiohttp-3.8.6-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d791245a894be071d5ab04bbb4850534261a7d4fd363b094a7b9963e8cdbd31"}, - {file = "aiohttp-3.8.6-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0cccd1de239afa866e4ce5c789b3032442f19c261c7d8a01183fd956b1935349"}, - {file = "aiohttp-3.8.6-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f13f60d78224f0dace220d8ab4ef1dbc37115eeeab8c06804fec11bec2bbd07"}, - {file = "aiohttp-3.8.6-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8a9b5a0606faca4f6cc0d338359d6fa137104c337f489cd135bb7fbdbccb1e39"}, - {file = "aiohttp-3.8.6-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:13da35c9ceb847732bf5c6c5781dcf4780e14392e5d3b3c689f6d22f8e15ae31"}, - {file = "aiohttp-3.8.6-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:4d4cbe4ffa9d05f46a28252efc5941e0462792930caa370a6efaf491f412bc66"}, - {file = "aiohttp-3.8.6-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:229852e147f44da0241954fc6cb910ba074e597f06789c867cb7fb0621e0ba7a"}, - {file = "aiohttp-3.8.6-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:713103a8bdde61d13490adf47171a1039fd880113981e55401a0f7b42c37d071"}, - {file = "aiohttp-3.8.6-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = 
"sha256:45ad816b2c8e3b60b510f30dbd37fe74fd4a772248a52bb021f6fd65dff809b6"}, - {file = "aiohttp-3.8.6-cp36-cp36m-win32.whl", hash = "sha256:2b8d4e166e600dcfbff51919c7a3789ff6ca8b3ecce16e1d9c96d95dd569eb4c"}, - {file = "aiohttp-3.8.6-cp36-cp36m-win_amd64.whl", hash = "sha256:0912ed87fee967940aacc5306d3aa8ba3a459fcd12add0b407081fbefc931e53"}, - {file = "aiohttp-3.8.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e2a988a0c673c2e12084f5e6ba3392d76c75ddb8ebc6c7e9ead68248101cd446"}, - {file = "aiohttp-3.8.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ebf3fd9f141700b510d4b190094db0ce37ac6361a6806c153c161dc6c041ccda"}, - {file = "aiohttp-3.8.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3161ce82ab85acd267c8f4b14aa226047a6bee1e4e6adb74b798bd42c6ae1f80"}, - {file = "aiohttp-3.8.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d95fc1bf33a9a81469aa760617b5971331cdd74370d1214f0b3109272c0e1e3c"}, - {file = "aiohttp-3.8.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c43ecfef7deaf0617cee936836518e7424ee12cb709883f2c9a1adda63cc460"}, - {file = "aiohttp-3.8.6-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca80e1b90a05a4f476547f904992ae81eda5c2c85c66ee4195bb8f9c5fb47f28"}, - {file = "aiohttp-3.8.6-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:90c72ebb7cb3a08a7f40061079817133f502a160561d0675b0a6adf231382c92"}, - {file = "aiohttp-3.8.6-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:bb54c54510e47a8c7c8e63454a6acc817519337b2b78606c4e840871a3e15349"}, - {file = "aiohttp-3.8.6-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:de6a1c9f6803b90e20869e6b99c2c18cef5cc691363954c93cb9adeb26d9f3ae"}, - {file = "aiohttp-3.8.6-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:a3628b6c7b880b181a3ae0a0683698513874df63783fd89de99b7b7539e3e8a8"}, - {file = 
"aiohttp-3.8.6-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:fc37e9aef10a696a5a4474802930079ccfc14d9f9c10b4662169671ff034b7df"}, - {file = "aiohttp-3.8.6-cp37-cp37m-win32.whl", hash = "sha256:f8ef51e459eb2ad8e7a66c1d6440c808485840ad55ecc3cafefadea47d1b1ba2"}, - {file = "aiohttp-3.8.6-cp37-cp37m-win_amd64.whl", hash = "sha256:b2fe42e523be344124c6c8ef32a011444e869dc5f883c591ed87f84339de5976"}, - {file = "aiohttp-3.8.6-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:9e2ee0ac5a1f5c7dd3197de309adfb99ac4617ff02b0603fd1e65b07dc772e4b"}, - {file = "aiohttp-3.8.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:01770d8c04bd8db568abb636c1fdd4f7140b284b8b3e0b4584f070180c1e5c62"}, - {file = "aiohttp-3.8.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3c68330a59506254b556b99a91857428cab98b2f84061260a67865f7f52899f5"}, - {file = "aiohttp-3.8.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89341b2c19fb5eac30c341133ae2cc3544d40d9b1892749cdd25892bbc6ac951"}, - {file = "aiohttp-3.8.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:71783b0b6455ac8f34b5ec99d83e686892c50498d5d00b8e56d47f41b38fbe04"}, - {file = "aiohttp-3.8.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f628dbf3c91e12f4d6c8b3f092069567d8eb17814aebba3d7d60c149391aee3a"}, - {file = "aiohttp-3.8.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b04691bc6601ef47c88f0255043df6f570ada1a9ebef99c34bd0b72866c217ae"}, - {file = "aiohttp-3.8.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ee912f7e78287516df155f69da575a0ba33b02dd7c1d6614dbc9463f43066e3"}, - {file = "aiohttp-3.8.6-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9c19b26acdd08dd239e0d3669a3dddafd600902e37881f13fbd8a53943079dbc"}, - {file = "aiohttp-3.8.6-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:99c5ac4ad492b4a19fc132306cd57075c28446ec2ed970973bbf036bcda1bcc6"}, - {file 
= "aiohttp-3.8.6-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:f0f03211fd14a6a0aed2997d4b1c013d49fb7b50eeb9ffdf5e51f23cfe2c77fa"}, - {file = "aiohttp-3.8.6-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:8d399dade330c53b4106160f75f55407e9ae7505263ea86f2ccca6bfcbdb4921"}, - {file = "aiohttp-3.8.6-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ec4fd86658c6a8964d75426517dc01cbf840bbf32d055ce64a9e63a40fd7b771"}, - {file = "aiohttp-3.8.6-cp38-cp38-win32.whl", hash = "sha256:33164093be11fcef3ce2571a0dccd9041c9a93fa3bde86569d7b03120d276c6f"}, - {file = "aiohttp-3.8.6-cp38-cp38-win_amd64.whl", hash = "sha256:bdf70bfe5a1414ba9afb9d49f0c912dc524cf60141102f3a11143ba3d291870f"}, - {file = "aiohttp-3.8.6-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d52d5dc7c6682b720280f9d9db41d36ebe4791622c842e258c9206232251ab2b"}, - {file = "aiohttp-3.8.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4ac39027011414dbd3d87f7edb31680e1f430834c8cef029f11c66dad0670aa5"}, - {file = "aiohttp-3.8.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3f5c7ce535a1d2429a634310e308fb7d718905487257060e5d4598e29dc17f0b"}, - {file = "aiohttp-3.8.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b30e963f9e0d52c28f284d554a9469af073030030cef8693106d918b2ca92f54"}, - {file = "aiohttp-3.8.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:918810ef188f84152af6b938254911055a72e0f935b5fbc4c1a4ed0b0584aed1"}, - {file = "aiohttp-3.8.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:002f23e6ea8d3dd8d149e569fd580c999232b5fbc601c48d55398fbc2e582e8c"}, - {file = "aiohttp-3.8.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4fcf3eabd3fd1a5e6092d1242295fa37d0354b2eb2077e6eb670accad78e40e1"}, - {file = "aiohttp-3.8.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:255ba9d6d5ff1a382bb9a578cd563605aa69bec845680e21c44afc2670607a95"}, - {file 
= "aiohttp-3.8.6-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d67f8baed00870aa390ea2590798766256f31dc5ed3ecc737debb6e97e2ede78"}, - {file = "aiohttp-3.8.6-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:86f20cee0f0a317c76573b627b954c412ea766d6ada1a9fcf1b805763ae7feeb"}, - {file = "aiohttp-3.8.6-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:39a312d0e991690ccc1a61f1e9e42daa519dcc34ad03eb6f826d94c1190190dd"}, - {file = "aiohttp-3.8.6-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e827d48cf802de06d9c935088c2924e3c7e7533377d66b6f31ed175c1620e05e"}, - {file = "aiohttp-3.8.6-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bd111d7fc5591ddf377a408ed9067045259ff2770f37e2d94e6478d0f3fc0c17"}, - {file = "aiohttp-3.8.6-cp39-cp39-win32.whl", hash = "sha256:caf486ac1e689dda3502567eb89ffe02876546599bbf915ec94b1fa424eeffd4"}, - {file = "aiohttp-3.8.6-cp39-cp39-win_amd64.whl", hash = "sha256:3f0e27e5b733803333bb2371249f41cf42bae8884863e8e8965ec69bebe53132"}, - {file = "aiohttp-3.8.6.tar.gz", hash = "sha256:b0cf2a4501bff9330a8a5248b4ce951851e415bdcce9dc158e76cfd55e15085c"}, + {file = "aiohttp-3.9.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e1f80197f8b0b846a8d5cf7b7ec6084493950d0882cc5537fb7b96a69e3c8590"}, + {file = "aiohttp-3.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c72444d17777865734aa1a4d167794c34b63e5883abb90356a0364a28904e6c0"}, + {file = "aiohttp-3.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9b05d5cbe9dafcdc733262c3a99ccf63d2f7ce02543620d2bd8db4d4f7a22f83"}, + {file = "aiohttp-3.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c4fa235d534b3547184831c624c0b7c1e262cd1de847d95085ec94c16fddcd5"}, + {file = "aiohttp-3.9.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:289ba9ae8e88d0ba16062ecf02dd730b34186ea3b1e7489046fc338bdc3361c4"}, + {file = "aiohttp-3.9.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:bff7e2811814fa2271be95ab6e84c9436d027a0e59665de60edf44e529a42c1f"}, + {file = "aiohttp-3.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81b77f868814346662c96ab36b875d7814ebf82340d3284a31681085c051320f"}, + {file = "aiohttp-3.9.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b9c7426923bb7bd66d409da46c41e3fb40f5caf679da624439b9eba92043fa6"}, + {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:8d44e7bf06b0c0a70a20f9100af9fcfd7f6d9d3913e37754c12d424179b4e48f"}, + {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22698f01ff5653fe66d16ffb7658f582a0ac084d7da1323e39fd9eab326a1f26"}, + {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ca7ca5abfbfe8d39e653870fbe8d7710be7a857f8a8386fc9de1aae2e02ce7e4"}, + {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:8d7f98fde213f74561be1d6d3fa353656197f75d4edfbb3d94c9eb9b0fc47f5d"}, + {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5216b6082c624b55cfe79af5d538e499cd5f5b976820eac31951fb4325974501"}, + {file = "aiohttp-3.9.1-cp310-cp310-win32.whl", hash = "sha256:0e7ba7ff228c0d9a2cd66194e90f2bca6e0abca810b786901a569c0de082f489"}, + {file = "aiohttp-3.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:c7e939f1ae428a86e4abbb9a7c4732bf4706048818dfd979e5e2839ce0159f23"}, + {file = "aiohttp-3.9.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:df9cf74b9bc03d586fc53ba470828d7b77ce51b0582d1d0b5b2fb673c0baa32d"}, + {file = "aiohttp-3.9.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ecca113f19d5e74048c001934045a2b9368d77b0b17691d905af18bd1c21275e"}, + {file = "aiohttp-3.9.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8cef8710fb849d97c533f259103f09bac167a008d7131d7b2b0e3a33269185c0"}, + {file = "aiohttp-3.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:bea94403a21eb94c93386d559bce297381609153e418a3ffc7d6bf772f59cc35"}, + {file = "aiohttp-3.9.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91c742ca59045dce7ba76cab6e223e41d2c70d79e82c284a96411f8645e2afff"}, + {file = "aiohttp-3.9.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6c93b7c2e52061f0925c3382d5cb8980e40f91c989563d3d32ca280069fd6a87"}, + {file = "aiohttp-3.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee2527134f95e106cc1653e9ac78846f3a2ec1004cf20ef4e02038035a74544d"}, + {file = "aiohttp-3.9.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:11ff168d752cb41e8492817e10fb4f85828f6a0142b9726a30c27c35a1835f01"}, + {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b8c3a67eb87394386847d188996920f33b01b32155f0a94f36ca0e0c635bf3e3"}, + {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c7b5d5d64e2a14e35a9240b33b89389e0035e6de8dbb7ffa50d10d8b65c57449"}, + {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:69985d50a2b6f709412d944ffb2e97d0be154ea90600b7a921f95a87d6f108a2"}, + {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:c9110c06eaaac7e1f5562caf481f18ccf8f6fdf4c3323feab28a93d34cc646bd"}, + {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d737e69d193dac7296365a6dcb73bbbf53bb760ab25a3727716bbd42022e8d7a"}, + {file = "aiohttp-3.9.1-cp311-cp311-win32.whl", hash = "sha256:4ee8caa925aebc1e64e98432d78ea8de67b2272252b0a931d2ac3bd876ad5544"}, + {file = "aiohttp-3.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:a34086c5cc285be878622e0a6ab897a986a6e8bf5b67ecb377015f06ed316587"}, + {file = "aiohttp-3.9.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f800164276eec54e0af5c99feb9494c295118fc10a11b997bbb1348ba1a52065"}, + {file = "aiohttp-3.9.1-cp312-cp312-macosx_10_9_x86_64.whl", 
hash = "sha256:500f1c59906cd142d452074f3811614be04819a38ae2b3239a48b82649c08821"}, + {file = "aiohttp-3.9.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0b0a6a36ed7e164c6df1e18ee47afbd1990ce47cb428739d6c99aaabfaf1b3af"}, + {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69da0f3ed3496808e8cbc5123a866c41c12c15baaaead96d256477edf168eb57"}, + {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:176df045597e674fa950bf5ae536be85699e04cea68fa3a616cf75e413737eb5"}, + {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b796b44111f0cab6bbf66214186e44734b5baab949cb5fb56154142a92989aeb"}, + {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f27fdaadce22f2ef950fc10dcdf8048407c3b42b73779e48a4e76b3c35bca26c"}, + {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bcb6532b9814ea7c5a6a3299747c49de30e84472fa72821b07f5a9818bce0f66"}, + {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:54631fb69a6e44b2ba522f7c22a6fb2667a02fd97d636048478db2fd8c4e98fe"}, + {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4b4c452d0190c5a820d3f5c0f3cd8a28ace48c54053e24da9d6041bf81113183"}, + {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:cae4c0c2ca800c793cae07ef3d40794625471040a87e1ba392039639ad61ab5b"}, + {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:565760d6812b8d78d416c3c7cfdf5362fbe0d0d25b82fed75d0d29e18d7fc30f"}, + {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:54311eb54f3a0c45efb9ed0d0a8f43d1bc6060d773f6973efd90037a51cd0a3f"}, + {file = "aiohttp-3.9.1-cp312-cp312-win32.whl", hash = "sha256:85c3e3c9cb1d480e0b9a64c658cd66b3cfb8e721636ab8b0e746e2d79a7a9eed"}, + {file = 
"aiohttp-3.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:11cb254e397a82efb1805d12561e80124928e04e9c4483587ce7390b3866d213"}, + {file = "aiohttp-3.9.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8a22a34bc594d9d24621091d1b91511001a7eea91d6652ea495ce06e27381f70"}, + {file = "aiohttp-3.9.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:598db66eaf2e04aa0c8900a63b0101fdc5e6b8a7ddd805c56d86efb54eb66672"}, + {file = "aiohttp-3.9.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2c9376e2b09895c8ca8b95362283365eb5c03bdc8428ade80a864160605715f1"}, + {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41473de252e1797c2d2293804e389a6d6986ef37cbb4a25208de537ae32141dd"}, + {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9c5857612c9813796960c00767645cb5da815af16dafb32d70c72a8390bbf690"}, + {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ffcd828e37dc219a72c9012ec44ad2e7e3066bec6ff3aaa19e7d435dbf4032ca"}, + {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:219a16763dc0294842188ac8a12262b5671817042b35d45e44fd0a697d8c8361"}, + {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f694dc8a6a3112059258a725a4ebe9acac5fe62f11c77ac4dcf896edfa78ca28"}, + {file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:bcc0ea8d5b74a41b621ad4a13d96c36079c81628ccc0b30cfb1603e3dfa3a014"}, + {file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:90ec72d231169b4b8d6085be13023ece8fa9b1bb495e4398d847e25218e0f431"}, + {file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:cf2a0ac0615842b849f40c4d7f304986a242f1e68286dbf3bd7a835e4f83acfd"}, + {file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = 
"sha256:0e49b08eafa4f5707ecfb321ab9592717a319e37938e301d462f79b4e860c32a"}, + {file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2c59e0076ea31c08553e868cec02d22191c086f00b44610f8ab7363a11a5d9d8"}, + {file = "aiohttp-3.9.1-cp38-cp38-win32.whl", hash = "sha256:4831df72b053b1eed31eb00a2e1aff6896fb4485301d4ccb208cac264b648db4"}, + {file = "aiohttp-3.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:3135713c5562731ee18f58d3ad1bf41e1d8883eb68b363f2ffde5b2ea4b84cc7"}, + {file = "aiohttp-3.9.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cfeadf42840c1e870dc2042a232a8748e75a36b52d78968cda6736de55582766"}, + {file = "aiohttp-3.9.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:70907533db712f7aa791effb38efa96f044ce3d4e850e2d7691abd759f4f0ae0"}, + {file = "aiohttp-3.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cdefe289681507187e375a5064c7599f52c40343a8701761c802c1853a504558"}, + {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7481f581251bb5558ba9f635db70908819caa221fc79ee52a7f58392778c636"}, + {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:49f0c1b3c2842556e5de35f122fc0f0b721334ceb6e78c3719693364d4af8499"}, + {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d406b01a9f5a7e232d1b0d161b40c05275ffbcbd772dc18c1d5a570961a1ca4"}, + {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d8e4450e7fe24d86e86b23cc209e0023177b6d59502e33807b732d2deb6975f"}, + {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c0266cd6f005e99f3f51e583012de2778e65af6b73860038b968a0a8888487a"}, + {file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab221850108a4a063c5b8a70f00dd7a1975e5a1713f87f4ab26a46e5feac5a0e"}, + {file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_i686.whl", hash = 
"sha256:c88a15f272a0ad3d7773cf3a37cc7b7d077cbfc8e331675cf1346e849d97a4e5"}, + {file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:237533179d9747080bcaad4d02083ce295c0d2eab3e9e8ce103411a4312991a0"}, + {file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:02ab6006ec3c3463b528374c4cdce86434e7b89ad355e7bf29e2f16b46c7dd6f"}, + {file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04fa38875e53eb7e354ece1607b1d2fdee2d175ea4e4d745f6ec9f751fe20c7c"}, + {file = "aiohttp-3.9.1-cp39-cp39-win32.whl", hash = "sha256:82eefaf1a996060602f3cc1112d93ba8b201dbf5d8fd9611227de2003dddb3b7"}, + {file = "aiohttp-3.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:9b05d33ff8e6b269e30a7957bd3244ffbce2a7a35a81b81c382629b80af1a8bf"}, + {file = "aiohttp-3.9.1.tar.gz", hash = "sha256:8fc49a87ac269d4529da45871e2ffb6874e87779c3d0e2ccd813c0899221239d"}, ] [package.dependencies] aiosignal = ">=1.1.2" -async-timeout = ">=4.0.0a3,<5.0" +async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} attrs = ">=17.3.0" -charset-normalizer = ">=2.0,<4.0" frozenlist = ">=1.1.1" multidict = ">=4.5,<7.0" yarl = ">=1.0,<2.0" [package.extras] -speedups = ["Brotli", "aiodns", "cchardet"] +speedups = ["Brotli", "aiodns", "brotlicffi"] [[package]] name = "aiosignal" @@ -163,13 +153,13 @@ files = [ [[package]] name = "anyio" -version = "4.0.0" +version = "4.1.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.8" files = [ - {file = "anyio-4.0.0-py3-none-any.whl", hash = "sha256:cfdb2b588b9fc25ede96d8db56ed50848b0b649dca3dd1df0b11f683bb9e0b5f"}, - {file = "anyio-4.0.0.tar.gz", hash = "sha256:f7ed51751b2c2add651e5747c891b47e26d2a21be5d32d9311dfe9692f3e5d7a"}, + {file = "anyio-4.1.0-py3-none-any.whl", hash = "sha256:56a415fbc462291813a94528a779597226619c8e78af7de0507333f700011e5f"}, + {file = "anyio-4.1.0.tar.gz", hash = 
"sha256:5a0bec7085176715be77df87fc66d6c9d70626bd752fcc85f57cdbee5b3760da"}, ] [package.dependencies] @@ -178,9 +168,9 @@ idna = ">=2.8" sniffio = ">=1.1" [package.extras] -doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] -trio = ["trio (>=0.22)"] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (>=0.23)"] [[package]] name = "appdirs" @@ -408,29 +398,29 @@ lxml = ["lxml"] [[package]] name = "black" -version = "23.10.1" +version = "23.11.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" files = [ - {file = "black-23.10.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:ec3f8e6234c4e46ff9e16d9ae96f4ef69fa328bb4ad08198c8cee45bb1f08c69"}, - {file = "black-23.10.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:1b917a2aa020ca600483a7b340c165970b26e9029067f019e3755b56e8dd5916"}, - {file = "black-23.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c74de4c77b849e6359c6f01987e94873c707098322b91490d24296f66d067dc"}, - {file = "black-23.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:7b4d10b0f016616a0d93d24a448100adf1699712fb7a4efd0e2c32bbb219b173"}, - {file = "black-23.10.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b15b75fc53a2fbcac8a87d3e20f69874d161beef13954747e053bca7a1ce53a0"}, - {file = "black-23.10.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:e293e4c2f4a992b980032bbd62df07c1bcff82d6964d6c9496f2cd726e246ace"}, - {file = "black-23.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:7d56124b7a61d092cb52cce34182a5280e160e6aff3137172a68c2c2c4b76bcb"}, - {file = "black-23.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:3f157a8945a7b2d424da3335f7ace89c14a3b0625e6593d21139c2d8214d55ce"}, - {file = "black-23.10.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:cfcce6f0a384d0da692119f2d72d79ed07c7159879d0bb1bb32d2e443382bf3a"}, - {file = "black-23.10.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:33d40f5b06be80c1bbce17b173cda17994fbad096ce60eb22054da021bf933d1"}, - {file = "black-23.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:840015166dbdfbc47992871325799fd2dc0dcf9395e401ada6d88fe11498abad"}, - {file = "black-23.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:037e9b4664cafda5f025a1728c50a9e9aedb99a759c89f760bd83730e76ba884"}, - {file = "black-23.10.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:7cb5936e686e782fddb1c73f8aa6f459e1ad38a6a7b0e54b403f1f05a1507ee9"}, - {file = "black-23.10.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:7670242e90dc129c539e9ca17665e39a146a761e681805c54fbd86015c7c84f7"}, - {file = "black-23.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed45ac9a613fb52dad3b61c8dea2ec9510bf3108d4db88422bacc7d1ba1243d"}, - {file = "black-23.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:6d23d7822140e3fef190734216cefb262521789367fbdc0b3f22af6744058982"}, - {file = "black-23.10.1-py3-none-any.whl", hash = "sha256:d431e6739f727bb2e0495df64a6c7a5310758e87505f5f8cde9ff6c0f2d7e4fe"}, - {file = "black-23.10.1.tar.gz", hash = "sha256:1f8ce316753428ff68749c65a5f7844631aa18c8679dfd3ca9dc1a289979c258"}, + {file = "black-23.11.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dbea0bb8575c6b6303cc65017b46351dc5953eea5c0a59d7b7e3a2d2f433a911"}, + {file = "black-23.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:412f56bab20ac85927f3a959230331de5614aecda1ede14b373083f62ec24e6f"}, + {file = 
"black-23.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d136ef5b418c81660ad847efe0e55c58c8208b77a57a28a503a5f345ccf01394"}, + {file = "black-23.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:6c1cac07e64433f646a9a838cdc00c9768b3c362805afc3fce341af0e6a9ae9f"}, + {file = "black-23.11.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cf57719e581cfd48c4efe28543fea3d139c6b6f1238b3f0102a9c73992cbb479"}, + {file = "black-23.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:698c1e0d5c43354ec5d6f4d914d0d553a9ada56c85415700b81dc90125aac244"}, + {file = "black-23.11.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:760415ccc20f9e8747084169110ef75d545f3b0932ee21368f63ac0fee86b221"}, + {file = "black-23.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:58e5f4d08a205b11800332920e285bd25e1a75c54953e05502052738fe16b3b5"}, + {file = "black-23.11.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:45aa1d4675964946e53ab81aeec7a37613c1cb71647b5394779e6efb79d6d187"}, + {file = "black-23.11.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c44b7211a3a0570cc097e81135faa5f261264f4dfaa22bd5ee2875a4e773bd6"}, + {file = "black-23.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a9acad1451632021ee0d146c8765782a0c3846e0e0ea46659d7c4f89d9b212b"}, + {file = "black-23.11.0-cp38-cp38-win_amd64.whl", hash = "sha256:fc7f6a44d52747e65a02558e1d807c82df1d66ffa80a601862040a43ec2e3142"}, + {file = "black-23.11.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7f622b6822f02bfaf2a5cd31fdb7cd86fcf33dab6ced5185c35f5db98260b055"}, + {file = "black-23.11.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:250d7e60f323fcfc8ea6c800d5eba12f7967400eb6c2d21ae85ad31c204fb1f4"}, + {file = "black-23.11.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5133f5507007ba08d8b7b263c7aa0f931af5ba88a29beacc4b2dc23fcefe9c06"}, + {file = "black-23.11.0-cp39-cp39-win_amd64.whl", hash = 
"sha256:421f3e44aa67138ab1b9bfbc22ee3780b22fa5b291e4db8ab7eee95200726b07"}, + {file = "black-23.11.0-py3-none-any.whl", hash = "sha256:54caaa703227c6e0c87b76326d0862184729a69b73d3b7305b6288e1d830067e"}, + {file = "black-23.11.0.tar.gz", hash = "sha256:4c68855825ff432d197229846f971bc4d6666ce90492e5b02013bcaca4d9ab05"}, ] [package.dependencies] @@ -468,13 +458,13 @@ css = ["tinycss2 (>=1.1.0,<1.3)"] [[package]] name = "certifi" -version = "2023.7.22" +version = "2023.11.17" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"}, - {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"}, + {file = "certifi-2023.11.17-py3-none-any.whl", hash = "sha256:e036ab49d5b79556f99cfc2d9320b34cfbe5be05c5871b51de9329f0603b0474"}, + {file = "certifi-2023.11.17.tar.gz", hash = "sha256:9b469f3a900bf28dc19b8cfbf8019bf47f7fdd1a65a1d4ffb98fc14166beb4d1"}, ] [[package]] @@ -543,112 +533,112 @@ pycparser = "*" [[package]] name = "charset-normalizer" -version = "3.3.1" +version = "3.3.2" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 
optional = false python-versions = ">=3.7.0" files = [ - {file = "charset-normalizer-3.3.1.tar.gz", hash = "sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = 
"sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-win32.whl", hash = "sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673"}, - {file = 
"charset_normalizer-3.3.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-win32.whl", hash = "sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296"}, - {file = 
"charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-win32.whl", hash = "sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6"}, - {file = 
"charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-win32.whl", hash = "sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae"}, - {file = 
"charset_normalizer-3.3.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = 
"sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-win32.whl", hash = "sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = 
"sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-win32.whl", hash = "sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727"}, - {file = "charset_normalizer-3.3.1-py3-none-any.whl", hash = "sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708"}, + {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"}, + {file = 
"charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"}, + {file = 
"charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"}, + {file = 
"charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786"}, + 
{file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-win32.whl", hash = "sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-win32.whl", hash = "sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0"}, + 
{file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-win32.whl", hash = "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d"}, + {file = 
"charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, ] [[package]] name = "circuitsvis" -version = "1.43.1" +version = "1.43.2" description = "Mechanistic Interpretability Visualizations" optional = false python-versions = ">=3.8" files = [ - {file = "circuitsvis-1.43.1-py3-none-any.whl", hash = "sha256:096138020986f79f1493c0ad8e94107a19b5d19cd5771b4401e231e993267019"}, - {file = "circuitsvis-1.43.1.tar.gz", hash = "sha256:5d730b9ee4c256cdf9c9da598343e7c8cd1eceeacfa8385969c008fa7123d6bc"}, + {file = "circuitsvis-1.43.2-py3-none-any.whl", hash = "sha256:1128fde5de8b738dd3c932d0b0ec4ee5556387b4405592fdf37f617e647183fb"}, + {file = "circuitsvis-1.43.2.tar.gz", hash = "sha256:388c1a6ea1bcf308da51fa6f67be761483ba361321d2e111f4c28faaea458287"}, ] [package.dependencies] @@ -699,22 +689,20 @@ files = [ [[package]] name = "comm" -version = "0.1.4" +version = "0.2.0" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." 
optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "comm-0.1.4-py3-none-any.whl", hash = "sha256:6d52794cba11b36ed9860999cd10fd02d6b2eac177068fdd585e1e2f8a96e67a"}, - {file = "comm-0.1.4.tar.gz", hash = "sha256:354e40a59c9dd6db50c5cc6b4acc887d82e9603787f83b68c01a80a923984d15"}, + {file = "comm-0.2.0-py3-none-any.whl", hash = "sha256:2da8d9ebb8dd7bfc247adaff99f24dce705638a8042b85cb995066793e391001"}, + {file = "comm-0.2.0.tar.gz", hash = "sha256:a517ea2ca28931c7007a7a99c562a0fa5883cfb48963140cf642c41c948498be"}, ] [package.dependencies] traitlets = ">=4" [package.extras] -lint = ["black (>=22.6.0)", "mdformat (>0.7)", "mdformat-gfm (>=0.3.5)", "ruff (>=0.0.156)"] test = ["pytest"] -typing = ["mypy (>=0.990)"] [[package]] name = "coverage" @@ -785,25 +773,26 @@ toml = ["tomli"] [[package]] name = "datasets" -version = "2.14.6" +version = "2.15.0" description = "HuggingFace community-driven open-source library of datasets" optional = false python-versions = ">=3.8.0" files = [ - {file = "datasets-2.14.6-py3-none-any.whl", hash = "sha256:4de857ffce21cfc847236745c69f102e33cd1f0fa8398e7be9964525fd4cd5db"}, - {file = "datasets-2.14.6.tar.gz", hash = "sha256:97ebbace8ec7af11434a87d1215379927f8fee2beab2c4a674003756ecfe920c"}, + {file = "datasets-2.15.0-py3-none-any.whl", hash = "sha256:6d658d23811393dfc982d026082e1650bdaaae28f6a86e651966cb072229a228"}, + {file = "datasets-2.15.0.tar.gz", hash = "sha256:a26d059370bd7503bd60e9337977199a13117a83f72fb61eda7e66f0c4d50b2b"}, ] [package.dependencies] aiohttp = "*" dill = ">=0.3.0,<0.3.8" fsspec = {version = ">=2023.1.0,<=2023.10.0", extras = ["http"]} -huggingface-hub = ">=0.14.0,<1.0.0" +huggingface-hub = ">=0.18.0" multiprocess = "*" numpy = ">=1.17" packaging = "*" pandas = "*" pyarrow = ">=8.0.0" +pyarrow-hotfix = "*" pyyaml = ">=5.1" requests = ">=2.19.0" tqdm = ">=4.62.1" @@ -813,15 +802,15 @@ xxhash = "*" apache-beam = ["apache-beam (>=2.26.0,<2.44.0)"] audio = ["librosa", "soundfile 
(>=0.12.1)"] benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"] -dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff (>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] +dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff (>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"] -jax = ["jax (>=0.2.8,!=0.3.2,<=0.3.25)", "jaxlib (>=0.1.65,<=0.3.25)"] +jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"] metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] quality = ["black (>=23.1,<24.0)", "pyyaml (>=5.3.1)", "ruff (>=0.0.241)"] s3 = ["s3fs"] tensorflow = ["tensorflow 
(>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] -tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] +tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] torch = ["torch"] vision = ["Pillow (>=6.2.1)"] @@ -926,13 +915,13 @@ files = [ [[package]] name = "exceptiongroup" -version = "1.1.3" +version = "1.2.0" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" files = [ - {file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"}, - {file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"}, + {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = "sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"}, + {file = "exceptiongroup-1.2.0.tar.gz", hash = "sha256:91f5c769735f051a4290d52edd0858999b57e5876e9f85937691bd4c9fa3ed68"}, ] [package.extras] @@ -940,13 +929,13 @@ test = ["pytest (>=6)"] [[package]] name = "executing" -version = "2.0.0" +version = "2.0.1" description = "Get 
the currently executing AST node of a frame, and other information" optional = false -python-versions = "*" +python-versions = ">=3.5" files = [ - {file = "executing-2.0.0-py2.py3-none-any.whl", hash = "sha256:06df6183df67389625f4e763921c6cf978944721abf3e714000200aab95b0657"}, - {file = "executing-2.0.0.tar.gz", hash = "sha256:0ff053696fdeef426cda5bd18eacd94f82c91f49823a2e9090124212ceea9b08"}, + {file = "executing-2.0.1-py2.py3-none-any.whl", hash = "sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc"}, + {file = "executing-2.0.1.tar.gz", hash = "sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147"}, ] [package.extras] @@ -965,13 +954,13 @@ files = [ [[package]] name = "fastjsonschema" -version = "2.18.1" +version = "2.19.0" description = "Fastest Python implementation of JSON schema" optional = false python-versions = "*" files = [ - {file = "fastjsonschema-2.18.1-py3-none-any.whl", hash = "sha256:aec6a19e9f66e9810ab371cc913ad5f4e9e479b63a7072a2cd060a9369e329a8"}, - {file = "fastjsonschema-2.18.1.tar.gz", hash = "sha256:06dc8680d937628e993fa0cd278f196d20449a1adc087640710846b324d422ea"}, + {file = "fastjsonschema-2.19.0-py3-none-any.whl", hash = "sha256:b9fd1a2dd6971dbc7fee280a95bd199ae0dd9ce22beb91cc75e9c1c528a5170e"}, + {file = "fastjsonschema-2.19.0.tar.gz", hash = "sha256:e25df6647e1bc4a26070b700897b07b542ec898dd4f1f6ea013e7f6a88417225"}, ] [package.extras] @@ -979,19 +968,19 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc [[package]] name = "filelock" -version = "3.12.4" +version = "3.13.1" description = "A platform independent file lock." 
optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.12.4-py3-none-any.whl", hash = "sha256:08c21d87ded6e2b9da6728c3dff51baf1dcecf973b768ef35bcbc3447edb9ad4"}, - {file = "filelock-3.12.4.tar.gz", hash = "sha256:2e6f249f1f3654291606e046b09f1fd5eac39b360664c27f5aad072012f8bcbd"}, + {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, + {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"}, ] [package.extras] -docs = ["furo (>=2023.7.26)", "sphinx (>=7.1.2)", "sphinx-autodoc-typehints (>=1.24)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3)", "diff-cover (>=7.7)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-timeout (>=2.1)"] -typing = ["typing-extensions (>=4.7.1)"] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +typing = ["typing-extensions (>=4.8)"] [[package]] name = "fqdn" @@ -1163,18 +1152,18 @@ test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre [[package]] name = "huggingface-hub" -version = "0.17.3" +version = "0.19.4" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.17.3-py3-none-any.whl", hash = "sha256:545eb3665f6ac587add946e73984148f2ea5c7877eac2e845549730570c1933a"}, - {file = "huggingface_hub-0.17.3.tar.gz", hash = "sha256:40439632b211311f788964602bf8b0d9d6b7a2314fba4e8d67b2ce3ecea0e3fd"}, + {file = "huggingface_hub-0.19.4-py3-none-any.whl", hash = "sha256:dba013f779da16f14b606492828f3760600a1e1801432d09fe1c33e50b825bb5"}, + {file = "huggingface_hub-0.19.4.tar.gz", hash = 
"sha256:176a4fc355a851c17550e7619488f383189727eab209534d7cef2114dae77b22"}, ] [package.dependencies] filelock = "*" -fsspec = "*" +fsspec = ">=2023.5.0" packaging = ">=20.9" pyyaml = ">=5.1" requests = "*" @@ -1182,27 +1171,27 @@ tqdm = ">=4.42.1" typing-extensions = ">=3.7.4.3" [package.extras] -all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (==23.7)", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (<2.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] cli = ["InquirerPy (==0.3.4)"] -dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (==23.7)", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (<2.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)"] -docs = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "black (==23.7)", "gradio", "hf-doc-builder", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (<2.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.0.241)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "urllib3 (<2.0)", "watchdog"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", 
"jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +docs = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "hf-doc-builder", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)", "watchdog"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] -inference = ["aiohttp", "pydantic (<2.0)"] -quality = ["black (==23.7)", "mypy (==1.5.1)", "ruff (>=0.0.241)"] +inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"] +quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"] tensorflow = ["graphviz", "pydot", "tensorflow"] -testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (<2.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] torch = ["torch"] -typing = ["pydantic (<2.0)", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"] +typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] [[package]] name = "idna" -version = "3.4" +version = "3.6" description = 
"Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.5" files = [ - {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, - {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, + {file = "idna-3.6-py3-none-any.whl", hash = "sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f"}, + {file = "idna-3.6.tar.gz", hash = "sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca"}, ] [[package]] @@ -1218,32 +1207,32 @@ files = [ [[package]] name = "importlib-metadata" -version = "6.8.0" +version = "7.0.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-6.8.0-py3-none-any.whl", hash = "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb"}, - {file = "importlib_metadata-6.8.0.tar.gz", hash = "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743"}, + {file = "importlib_metadata-7.0.0-py3-none-any.whl", hash = "sha256:d97503976bb81f40a193d41ee6570868479c69d5068651eb039c40d850c59d67"}, + {file = "importlib_metadata-7.0.0.tar.gz", hash = "sha256:7fc841f8b8332803464e5dc1c63a2e59121f46ca186c0e2e182e80bf8c1319f7"}, ] [package.dependencies] zipp = ">=0.5" [package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] [[package]] name = "importlib-resources" -version = 
"6.1.0" +version = "6.1.1" description = "Read resources from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_resources-6.1.0-py3-none-any.whl", hash = "sha256:aa50258bbfa56d4e33fbd8aa3ef48ded10d1735f11532b8df95388cc6bdb7e83"}, - {file = "importlib_resources-6.1.0.tar.gz", hash = "sha256:9d48dcccc213325e810fd723e7fbb45ccb39f6cf5c31f00cf2b965f5f10f3cb9"}, + {file = "importlib_resources-6.1.1-py3-none-any.whl", hash = "sha256:e8bf90d8213b486f428c9c39714b920041cb02c184686a3dee24905aaa8105d6"}, + {file = "importlib_resources-6.1.1.tar.gz", hash = "sha256:3893a00122eafde6894c59914446a512f728a0c1a45f9bb9b63721b6bacf0b4a"}, ] [package.dependencies] @@ -1266,13 +1255,13 @@ files = [ [[package]] name = "ipykernel" -version = "6.26.0" +version = "6.27.1" description = "IPython Kernel for Jupyter" optional = false python-versions = ">=3.8" files = [ - {file = "ipykernel-6.26.0-py3-none-any.whl", hash = "sha256:3ba3dc97424b87b31bb46586b5167b3161b32d7820b9201a9e698c71e271602c"}, - {file = "ipykernel-6.26.0.tar.gz", hash = "sha256:553856658eb8430bbe9653ea041a41bff63e9606fc4628873fc92a6cf3abd404"}, + {file = "ipykernel-6.27.1-py3-none-any.whl", hash = "sha256:dab88b47f112f9f7df62236511023c9bdeef67abc73af7c652e4ce4441601686"}, + {file = "ipykernel-6.27.1.tar.gz", hash = "sha256:7d5d594b6690654b4d299edba5e872dc17bb7396a8d0609c97cb7b8a1c605de6"}, ] [package.dependencies] @@ -1336,17 +1325,6 @@ qtconsole = ["qtconsole"] test = ["pytest (<7.1)", "pytest-asyncio", "testpath"] test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pandas", "pytest (<7.1)", "pytest-asyncio", "testpath", "trio"] -[[package]] -name = "ipython-genutils" -version = "0.2.0" -description = "Vestigial utilities from IPython" -optional = false -python-versions = "*" -files = [ - {file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"}, - {file = 
"ipython_genutils-0.2.0.tar.gz", hash = "sha256:eb2e116e75ecef9d4d228fdc66af54269afa26ab4463042e33785b887c628ba8"}, -] - [[package]] name = "ipywidgets" version = "8.1.1" @@ -1477,13 +1455,13 @@ files = [ [[package]] name = "jsonschema" -version = "4.19.1" +version = "4.20.0" description = "An implementation of JSON Schema validation for Python" optional = false python-versions = ">=3.8" files = [ - {file = "jsonschema-4.19.1-py3-none-any.whl", hash = "sha256:cd5f1f9ed9444e554b38ba003af06c0a8c2868131e56bfbef0550fb450c0330e"}, - {file = "jsonschema-4.19.1.tar.gz", hash = "sha256:ec84cc37cfa703ef7cd4928db24f9cb31428a5d0fa77747b8b51a847458e0bbf"}, + {file = "jsonschema-4.20.0-py3-none-any.whl", hash = "sha256:ed6231f0429ecf966f5bc8dfef245998220549cbbcf140f913b7464c52c3b6b3"}, + {file = "jsonschema-4.20.0.tar.gz", hash = "sha256:4f614fd46d8d61258610998997743ec5492a648b33cf478c1ddc23ed4598a5fa"}, ] [package.dependencies] @@ -1508,18 +1486,18 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- [[package]] name = "jsonschema-specifications" -version = "2023.7.1" +version = "2023.11.2" description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" optional = false python-versions = ">=3.8" files = [ - {file = "jsonschema_specifications-2023.7.1-py3-none-any.whl", hash = "sha256:05adf340b659828a004220a9613be00fa3f223f2b82002e273dee62fd50524b1"}, - {file = "jsonschema_specifications-2023.7.1.tar.gz", hash = "sha256:c91a50404e88a1f6ba40636778e2ee08f6e24c5613fe4c53ac24578a5a7f72bb"}, + {file = "jsonschema_specifications-2023.11.2-py3-none-any.whl", hash = "sha256:e74ba7c0a65e8cb49dc26837d6cfe576557084a8b423ed16a420984228104f93"}, + {file = "jsonschema_specifications-2023.11.2.tar.gz", hash = "sha256:9472fc4fea474cd74bea4a2b190daeccb5a9e4db2ea80efcf7a1b582fc9a81b8"}, ] [package.dependencies] importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""} -referencing = ">=0.28.0" +referencing = ">=0.31.0" 
[[package]] name = "jupyter" @@ -1543,13 +1521,13 @@ qtconsole = "*" [[package]] name = "jupyter-client" -version = "8.5.0" +version = "8.6.0" description = "Jupyter protocol implementation and client libraries" optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_client-8.5.0-py3-none-any.whl", hash = "sha256:c3877aac7257ec68d79b5c622ce986bd2a992ca42f6ddc9b4dd1da50e89f7028"}, - {file = "jupyter_client-8.5.0.tar.gz", hash = "sha256:e8754066510ce456358df363f97eae64b50860f30dc1fe8c6771440db3be9a63"}, + {file = "jupyter_client-8.6.0-py3-none-any.whl", hash = "sha256:909c474dbe62582ae62b758bca86d6518c85234bdee2d908c778db6d72f39d99"}, + {file = "jupyter_client-8.6.0.tar.gz", hash = "sha256:0642244bb83b4764ae60d07e010e15f0e2d275ec4e918a8f7b80fbbef3ca60c7"}, ] [package.dependencies] @@ -1590,13 +1568,13 @@ test = ["flaky", "pexpect", "pytest"] [[package]] name = "jupyter-core" -version = "5.4.0" +version = "5.5.0" description = "Jupyter core package. A base package on which Jupyter projects rely." 
optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_core-5.4.0-py3-none-any.whl", hash = "sha256:66e252f675ac04dcf2feb6ed4afb3cd7f68cf92f483607522dc251f32d471571"}, - {file = "jupyter_core-5.4.0.tar.gz", hash = "sha256:e4b98344bb94ee2e3e6c4519a97d001656009f9cb2b7f2baf15b3c205770011d"}, + {file = "jupyter_core-5.5.0-py3-none-any.whl", hash = "sha256:e11e02cd8ae0a9de5c6c44abf5727df9f2581055afe00b22183f621ba3585805"}, + {file = "jupyter_core-5.5.0.tar.gz", hash = "sha256:880b86053bf298a8724994f95e99b99130659022a4f7f45f563084b6223861d3"}, ] [package.dependencies] @@ -1605,18 +1583,18 @@ pywin32 = {version = ">=300", markers = "sys_platform == \"win32\" and platform_ traitlets = ">=5.3" [package.extras] -docs = ["myst-parser", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] [[package]] name = "jupyter-events" -version = "0.8.0" +version = "0.9.0" description = "Jupyter Event System library" optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_events-0.8.0-py3-none-any.whl", hash = "sha256:81f07375c7673ff298bfb9302b4a981864ec64edaed75ca0fe6f850b9b045525"}, - {file = "jupyter_events-0.8.0.tar.gz", hash = "sha256:fda08f0defce5e16930542ce60634ba48e010830d50073c3dfd235759cee77bf"}, + {file = "jupyter_events-0.9.0-py3-none-any.whl", hash = "sha256:d853b3c10273ff9bc8bb8b30076d65e2c9685579db736873de6c2232dde148bf"}, + {file = "jupyter_events-0.9.0.tar.gz", hash = "sha256:81ad2e4bc710881ec274d31c6c50669d71bbaa5dd9d01e600b56faa85700d399"}, ] [package.dependencies] @@ -1635,13 +1613,13 @@ test = ["click", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>=0.19.0)", "p [[package]] name = "jupyter-lsp" -version = "2.2.0" +version = "2.2.1" description = "Multi-Language Server 
WebSocket proxy for Jupyter Notebook/Lab server" optional = false python-versions = ">=3.8" files = [ - {file = "jupyter-lsp-2.2.0.tar.gz", hash = "sha256:8ebbcb533adb41e5d635eb8fe82956b0aafbf0fd443b6c4bfa906edeeb8635a1"}, - {file = "jupyter_lsp-2.2.0-py3-none-any.whl", hash = "sha256:9e06b8b4f7dd50300b70dd1a78c0c3b0c3d8fa68e0f2d8a5d1fbab62072aca3f"}, + {file = "jupyter-lsp-2.2.1.tar.gz", hash = "sha256:b17fab6d70fe83c8896b0cff59237640038247c196056b43684a0902b6a9e0fb"}, + {file = "jupyter_lsp-2.2.1-py3-none-any.whl", hash = "sha256:17a689910c5e4ae5e7d334b02f31d08ffbe98108f6f658fb05e4304b4345368b"}, ] [package.dependencies] @@ -1650,13 +1628,13 @@ jupyter-server = ">=1.1.2" [[package]] name = "jupyter-server" -version = "2.9.1" +version = "2.12.1" description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_server-2.9.1-py3-none-any.whl", hash = "sha256:21ad1a3d455d5a79ce4bef5201925cd17510c17898cf9d54e3ccfb6b12734948"}, - {file = "jupyter_server-2.9.1.tar.gz", hash = "sha256:9ba71be4b9c16e479e4c50c929f8ac4b1015baf90237a08681397a98c76c7e5e"}, + {file = "jupyter_server-2.12.1-py3-none-any.whl", hash = "sha256:fd030dd7be1ca572e4598203f718df6630c12bd28a599d7f1791c4d7938e1010"}, + {file = "jupyter_server-2.12.1.tar.gz", hash = "sha256:dc77b7dcc5fc0547acba2b2844f01798008667201eea27c6319ff9257d700a6d"}, ] [package.dependencies] @@ -1665,7 +1643,7 @@ argon2-cffi = "*" jinja2 = "*" jupyter-client = ">=7.4.4" jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" -jupyter-events = ">=0.6.0" +jupyter-events = ">=0.9.0" jupyter-server-terminals = "*" nbconvert = ">=6.4.4" nbformat = ">=5.3.0" @@ -1705,13 +1683,13 @@ test = ["coverage", "jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-cov", [[package]] name = "jupyterlab" -version = "4.0.7" +version = "4.0.9" description = "JupyterLab computational environment" optional = false python-versions = ">=3.8" files = [ - 
{file = "jupyterlab-4.0.7-py3-none-any.whl", hash = "sha256:08683045117cc495531fdb39c22ababb9aaac6977a45e67cfad20046564c9c7c"}, - {file = "jupyterlab-4.0.7.tar.gz", hash = "sha256:48792efd9f962b2bcda1f87d72168ff122c288b1d97d32109e4a11b33dc862be"}, + {file = "jupyterlab-4.0.9-py3-none-any.whl", hash = "sha256:9f6f8e36d543fdbcc3df961a1d6a3f524b4a4001be0327a398f68fa4e534107c"}, + {file = "jupyterlab-4.0.9.tar.gz", hash = "sha256:9ebada41d52651f623c0c9f069ddb8a21d6848e4c887d8e5ddc0613166ed5c0b"}, ] [package.dependencies] @@ -1731,31 +1709,31 @@ tornado = ">=6.2.0" traitlets = "*" [package.extras] -dev = ["black[jupyter] (==23.7.0)", "build", "bump2version", "coverage", "hatch", "pre-commit", "pytest-cov", "ruff (==0.0.286)"] +dev = ["black[jupyter] (==23.10.1)", "build", "bump2version", "coverage", "hatch", "pre-commit", "pytest-cov", "ruff (==0.1.4)"] docs = ["jsx-lexer", "myst-parser", "pydata-sphinx-theme (>=0.13.0)", "pytest", "pytest-check-links", "pytest-tornasync", "sphinx (>=1.8,<7.2.0)", "sphinx-copybutton"] docs-screenshots = ["altair (==5.0.1)", "ipython (==8.14.0)", "ipywidgets (==8.0.6)", "jupyterlab-geojson (==3.4.0)", "jupyterlab-language-pack-zh-cn (==4.0.post0)", "matplotlib (==3.7.1)", "nbconvert (>=7.0.0)", "pandas (==2.0.2)", "scipy (==1.10.1)", "vega-datasets (==0.9.0)"] test = ["coverage", "pytest (>=7.0)", "pytest-check-links (>=0.7)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter (>=0.5.3)", "pytest-timeout", "pytest-tornasync", "requests", "requests-cache", "virtualenv"] [[package]] name = "jupyterlab-pygments" -version = "0.2.2" +version = "0.3.0" description = "Pygments theme using JupyterLab CSS variables" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "jupyterlab_pygments-0.2.2-py2.py3-none-any.whl", hash = "sha256:2405800db07c9f770863bcf8049a529c3dd4d3e28536638bd7c1c01d2748309f"}, - {file = "jupyterlab_pygments-0.2.2.tar.gz", hash = 
"sha256:7405d7fde60819d905a9fa8ce89e4cd830e318cdad22a0030f7a901da705585d"}, + {file = "jupyterlab_pygments-0.3.0-py3-none-any.whl", hash = "sha256:841a89020971da1d8693f1a99997aefc5dc424bb1b251fd6322462a1b8842780"}, + {file = "jupyterlab_pygments-0.3.0.tar.gz", hash = "sha256:721aca4d9029252b11cfa9d185e5b5af4d54772bb8072f9b7036f4170054d35d"}, ] [[package]] name = "jupyterlab-server" -version = "2.25.0" +version = "2.25.2" description = "A set of server components for JupyterLab and JupyterLab like applications." optional = false python-versions = ">=3.8" files = [ - {file = "jupyterlab_server-2.25.0-py3-none-any.whl", hash = "sha256:c9f67a98b295c5dee87f41551b0558374e45d449f3edca153dd722140630dcb2"}, - {file = "jupyterlab_server-2.25.0.tar.gz", hash = "sha256:77c2f1f282d610f95e496e20d5bf1d2a7706826dfb7b18f3378ae2870d272fb7"}, + {file = "jupyterlab_server-2.25.2-py3-none-any.whl", hash = "sha256:5b1798c9cc6a44f65c757de9f97fc06fc3d42535afbf47d2ace5e964ab447aaf"}, + {file = "jupyterlab_server-2.25.2.tar.gz", hash = "sha256:bd0ec7a99ebcedc8bcff939ef86e52c378e44c2707e053fcd81d046ce979ee63"}, ] [package.dependencies] @@ -1771,7 +1749,7 @@ requests = ">=2.31" [package.extras] docs = ["autodoc-traits", "jinja2 (<3.2.0)", "mistune (<4)", "myst-parser", "pydata-sphinx-theme", "sphinx", "sphinx-copybutton", "sphinxcontrib-openapi (>0.8)"] openapi = ["openapi-core (>=0.18.0,<0.19.0)", "ruamel-yaml"] -test = ["hatch", "ipykernel", "openapi-core (>=0.18.0,<0.19.0)", "openapi-spec-validator (>=0.6.0,<0.7.0)", "pytest (>=7.0)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter[server] (>=0.6.2)", "pytest-timeout", "requests-mock", "ruamel-yaml", "sphinxcontrib-spelling", "strict-rfc3339", "werkzeug"] +test = ["hatch", "ipykernel", "openapi-core (>=0.18.0,<0.19.0)", "openapi-spec-validator (>=0.6.0,<0.8.0)", "pytest (>=7.0)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter[server] (>=0.6.2)", "pytest-timeout", "requests-mock", "ruamel-yaml", "sphinxcontrib-spelling", 
"strict-rfc3339", "werkzeug"] [[package]] name = "jupyterlab-widgets" @@ -1898,6 +1876,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = 
"MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -2115,38 +2103,38 @@ dill = ">=0.3.7" [[package]] name = "mypy" -version = "1.6.1" +version = "1.7.1" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" files = [ - {file = "mypy-1.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e5012e5cc2ac628177eaac0e83d622b2dd499e28253d4107a08ecc59ede3fc2c"}, - {file = "mypy-1.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d8fbb68711905f8912e5af474ca8b78d077447d8f3918997fecbf26943ff3cbb"}, - {file = "mypy-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21a1ad938fee7d2d96ca666c77b7c494c3c5bd88dff792220e1afbebb2925b5e"}, - {file = "mypy-1.6.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b96ae2c1279d1065413965c607712006205a9ac541895004a1e0d4f281f2ff9f"}, - {file = "mypy-1.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:40b1844d2e8b232ed92e50a4bd11c48d2daa351f9deee6c194b83bf03e418b0c"}, - {file = "mypy-1.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:81af8adaa5e3099469e7623436881eff6b3b06db5ef75e6f5b6d4871263547e5"}, - {file = "mypy-1.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8c223fa57cb154c7eab5156856c231c3f5eace1e0bed9b32a24696b7ba3c3245"}, - {file = "mypy-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8032e00ce71c3ceb93eeba63963b864bf635a18f6c0c12da6c13c450eedb183"}, - 
{file = "mypy-1.6.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4c46b51de523817a0045b150ed11b56f9fff55f12b9edd0f3ed35b15a2809de0"}, - {file = "mypy-1.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:19f905bcfd9e167159b3d63ecd8cb5e696151c3e59a1742e79bc3bcb540c42c7"}, - {file = "mypy-1.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:82e469518d3e9a321912955cc702d418773a2fd1e91c651280a1bda10622f02f"}, - {file = "mypy-1.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d4473c22cc296425bbbce7e9429588e76e05bc7342da359d6520b6427bf76660"}, - {file = "mypy-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59a0d7d24dfb26729e0a068639a6ce3500e31d6655df8557156c51c1cb874ce7"}, - {file = "mypy-1.6.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cfd13d47b29ed3bbaafaff7d8b21e90d827631afda134836962011acb5904b71"}, - {file = "mypy-1.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:eb4f18589d196a4cbe5290b435d135dee96567e07c2b2d43b5c4621b6501531a"}, - {file = "mypy-1.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:41697773aa0bf53ff917aa077e2cde7aa50254f28750f9b88884acea38a16169"}, - {file = "mypy-1.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7274b0c57737bd3476d2229c6389b2ec9eefeb090bbaf77777e9d6b1b5a9d143"}, - {file = "mypy-1.6.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbaf4662e498c8c2e352da5f5bca5ab29d378895fa2d980630656178bd607c46"}, - {file = "mypy-1.6.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bb8ccb4724f7d8601938571bf3f24da0da791fe2db7be3d9e79849cb64e0ae85"}, - {file = "mypy-1.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:68351911e85145f582b5aa6cd9ad666c8958bcae897a1bfda8f4940472463c45"}, - {file = "mypy-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:49ae115da099dcc0922a7a895c1eec82c1518109ea5c162ed50e3b3594c71208"}, - {file = "mypy-1.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8b27958f8c76bed8edaa63da0739d76e4e9ad4ed325c814f9b3851425582a3cd"}, - 
{file = "mypy-1.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:925cd6a3b7b55dfba252b7c4561892311c5358c6b5a601847015a1ad4eb7d332"}, - {file = "mypy-1.6.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8f57e6b6927a49550da3d122f0cb983d400f843a8a82e65b3b380d3d7259468f"}, - {file = "mypy-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:a43ef1c8ddfdb9575691720b6352761f3f53d85f1b57d7745701041053deff30"}, - {file = "mypy-1.6.1-py3-none-any.whl", hash = "sha256:4cbe68ef919c28ea561165206a2dcb68591c50f3bcf777932323bc208d949cf1"}, - {file = "mypy-1.6.1.tar.gz", hash = "sha256:4d01c00d09a0be62a4ca3f933e315455bde83f37f892ba4b08ce92f3cf44bcc1"}, + {file = "mypy-1.7.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:12cce78e329838d70a204293e7b29af9faa3ab14899aec397798a4b41be7f340"}, + {file = "mypy-1.7.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1484b8fa2c10adf4474f016e09d7a159602f3239075c7bf9f1627f5acf40ad49"}, + {file = "mypy-1.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31902408f4bf54108bbfb2e35369877c01c95adc6192958684473658c322c8a5"}, + {file = "mypy-1.7.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f2c2521a8e4d6d769e3234350ba7b65ff5d527137cdcde13ff4d99114b0c8e7d"}, + {file = "mypy-1.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:fcd2572dd4519e8a6642b733cd3a8cfc1ef94bafd0c1ceed9c94fe736cb65b6a"}, + {file = "mypy-1.7.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4b901927f16224d0d143b925ce9a4e6b3a758010673eeded9b748f250cf4e8f7"}, + {file = "mypy-1.7.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2f7f6985d05a4e3ce8255396df363046c28bea790e40617654e91ed580ca7c51"}, + {file = "mypy-1.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:944bdc21ebd620eafefc090cdf83158393ec2b1391578359776c00de00e8907a"}, + {file = "mypy-1.7.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9c7ac372232c928fff0645d85f273a726970c014749b924ce5710d7d89763a28"}, + {file 
= "mypy-1.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:f6efc9bd72258f89a3816e3a98c09d36f079c223aa345c659622f056b760ab42"}, + {file = "mypy-1.7.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6dbdec441c60699288adf051f51a5d512b0d818526d1dcfff5a41f8cd8b4aaf1"}, + {file = "mypy-1.7.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4fc3d14ee80cd22367caaaf6e014494415bf440980a3045bf5045b525680ac33"}, + {file = "mypy-1.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c6e4464ed5f01dc44dc9821caf67b60a4e5c3b04278286a85c067010653a0eb"}, + {file = "mypy-1.7.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:d9b338c19fa2412f76e17525c1b4f2c687a55b156320acb588df79f2e6fa9fea"}, + {file = "mypy-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:204e0d6de5fd2317394a4eff62065614c4892d5a4d1a7ee55b765d7a3d9e3f82"}, + {file = "mypy-1.7.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:84860e06ba363d9c0eeabd45ac0fde4b903ad7aa4f93cd8b648385a888e23200"}, + {file = "mypy-1.7.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8c5091ebd294f7628eb25ea554852a52058ac81472c921150e3a61cdd68f75a7"}, + {file = "mypy-1.7.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40716d1f821b89838589e5b3106ebbc23636ffdef5abc31f7cd0266db936067e"}, + {file = "mypy-1.7.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5cf3f0c5ac72139797953bd50bc6c95ac13075e62dbfcc923571180bebb662e9"}, + {file = "mypy-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:78e25b2fd6cbb55ddfb8058417df193f0129cad5f4ee75d1502248e588d9e0d7"}, + {file = "mypy-1.7.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:75c4d2a6effd015786c87774e04331b6da863fc3fc4e8adfc3b40aa55ab516fe"}, + {file = "mypy-1.7.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2643d145af5292ee956aa0a83c2ce1038a3bdb26e033dadeb2f7066fb0c9abce"}, + {file = "mypy-1.7.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:75aa828610b67462ffe3057d4d8a4112105ed211596b750b53cbfe182f44777a"}, + {file = "mypy-1.7.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ee5d62d28b854eb61889cde4e1dbc10fbaa5560cb39780c3995f6737f7e82120"}, + {file = "mypy-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:72cf32ce7dd3562373f78bd751f73c96cfb441de147cc2448a92c1a308bd0ca6"}, + {file = "mypy-1.7.1-py3-none-any.whl", hash = "sha256:f7c5d642db47376a0cc130f0de6d055056e010debdaf0707cd2b0fc7e7ef30ea"}, + {file = "mypy-1.7.1.tar.gz", hash = "sha256:fcb6d9afb1b6208b4c712af0dafdc650f518836065df0d4fb1d800f5d6773db2"}, ] [package.dependencies] @@ -2157,6 +2145,7 @@ typing-extensions = ">=4.1.0" [package.extras] dmypy = ["psutil (>=4.0)"] install-types = ["pip"] +mypyc = ["setuptools (>=50)"] reports = ["lxml"] [[package]] @@ -2198,13 +2187,13 @@ testing-docutils = ["pygments", "pytest (>=7,<8)", "pytest-param-files (>=0.3.4, [[package]] name = "nbclient" -version = "0.8.0" +version = "0.9.0" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." 
optional = false python-versions = ">=3.8.0" files = [ - {file = "nbclient-0.8.0-py3-none-any.whl", hash = "sha256:25e861299e5303a0477568557c4045eccc7a34c17fc08e7959558707b9ebe548"}, - {file = "nbclient-0.8.0.tar.gz", hash = "sha256:f9b179cd4b2d7bca965f900a2ebf0db4a12ebff2f36a711cb66861e4ae158e55"}, + {file = "nbclient-0.9.0-py3-none-any.whl", hash = "sha256:a3a1ddfb34d4a9d17fc744d655962714a866639acd30130e9be84191cd97cd15"}, + {file = "nbclient-0.9.0.tar.gz", hash = "sha256:4b28c207877cf33ef3a9838cdc7a54c5ceff981194a82eac59d558f05487295e"}, ] [package.dependencies] @@ -2220,13 +2209,13 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= [[package]] name = "nbconvert" -version = "7.9.2" +version = "7.12.0" description = "Converting Jupyter Notebooks" optional = false python-versions = ">=3.8" files = [ - {file = "nbconvert-7.9.2-py3-none-any.whl", hash = "sha256:39fe4b8bdd1b0104fdd86fc8a43a9077ba64c720bda4c6132690d917a0a154ee"}, - {file = "nbconvert-7.9.2.tar.gz", hash = "sha256:e56cc7588acc4f93e2bb5a34ec69028e4941797b2bfaf6462f18a41d1cc258c9"}, + {file = "nbconvert-7.12.0-py3-none-any.whl", hash = "sha256:5b6c848194d270cc55fb691169202620d7b52a12fec259508d142ecbe4219310"}, + {file = "nbconvert-7.12.0.tar.gz", hash = "sha256:b1564bd89f69a74cd6398b0362da94db07aafb991b7857216a766204a71612c0"}, ] [package.dependencies] @@ -2253,7 +2242,7 @@ docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sp qtpdf = ["nbconvert[qtpng]"] qtpng = ["pyqtwebengine (>=5.15)"] serve = ["tornado (>=6.1)"] -test = ["flaky", "ipykernel", "ipywidgets (>=7)", "pytest", "pytest-dependency"] +test = ["flaky", "ipykernel", "ipywidgets (>=7)", "pytest"] webpdf = ["playwright"] [[package]] @@ -2422,43 +2411,47 @@ files = [ [[package]] name = "numpy" -version = "1.26.1" +version = "1.26.2" description = "Fundamental package for array computing in Python" optional = false -python-versions = "<3.13,>=3.9" -files = [ - {file = 
"numpy-1.26.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82e871307a6331b5f09efda3c22e03c095d957f04bf6bc1804f30048d0e5e7af"}, - {file = "numpy-1.26.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cdd9ec98f0063d93baeb01aad472a1a0840dee302842a2746a7a8e92968f9575"}, - {file = "numpy-1.26.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d78f269e0c4fd365fc2992c00353e4530d274ba68f15e968d8bc3c69ce5f5244"}, - {file = "numpy-1.26.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ab9163ca8aeb7fd32fe93866490654d2f7dda4e61bc6297bf72ce07fdc02f67"}, - {file = "numpy-1.26.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:78ca54b2f9daffa5f323f34cdf21e1d9779a54073f0018a3094ab907938331a2"}, - {file = "numpy-1.26.1-cp310-cp310-win32.whl", hash = "sha256:d1cfc92db6af1fd37a7bb58e55c8383b4aa1ba23d012bdbba26b4bcca45ac297"}, - {file = "numpy-1.26.1-cp310-cp310-win_amd64.whl", hash = "sha256:d2984cb6caaf05294b8466966627e80bf6c7afd273279077679cb010acb0e5ab"}, - {file = "numpy-1.26.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cd7837b2b734ca72959a1caf3309457a318c934abef7a43a14bb984e574bbb9a"}, - {file = "numpy-1.26.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1c59c046c31a43310ad0199d6299e59f57a289e22f0f36951ced1c9eac3665b9"}, - {file = "numpy-1.26.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d58e8c51a7cf43090d124d5073bc29ab2755822181fcad978b12e144e5e5a4b3"}, - {file = "numpy-1.26.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6081aed64714a18c72b168a9276095ef9155dd7888b9e74b5987808f0dd0a974"}, - {file = "numpy-1.26.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:97e5d6a9f0702c2863aaabf19f0d1b6c2628fbe476438ce0b5ce06e83085064c"}, - {file = "numpy-1.26.1-cp311-cp311-win32.whl", hash = "sha256:b9d45d1dbb9de84894cc50efece5b09939752a2d75aab3a8b0cef6f3a35ecd6b"}, - {file = "numpy-1.26.1-cp311-cp311-win_amd64.whl", hash = 
"sha256:3649d566e2fc067597125428db15d60eb42a4e0897fc48d28cb75dc2e0454e53"}, - {file = "numpy-1.26.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1d1bd82d539607951cac963388534da3b7ea0e18b149a53cf883d8f699178c0f"}, - {file = "numpy-1.26.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:afd5ced4e5a96dac6725daeb5242a35494243f2239244fad10a90ce58b071d24"}, - {file = "numpy-1.26.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a03fb25610ef560a6201ff06df4f8105292ba56e7cdd196ea350d123fc32e24e"}, - {file = "numpy-1.26.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcfaf015b79d1f9f9c9fd0731a907407dc3e45769262d657d754c3a028586124"}, - {file = "numpy-1.26.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e509cbc488c735b43b5ffea175235cec24bbc57b227ef1acc691725beb230d1c"}, - {file = "numpy-1.26.1-cp312-cp312-win32.whl", hash = "sha256:af22f3d8e228d84d1c0c44c1fbdeb80f97a15a0abe4f080960393a00db733b66"}, - {file = "numpy-1.26.1-cp312-cp312-win_amd64.whl", hash = "sha256:9f42284ebf91bdf32fafac29d29d4c07e5e9d1af862ea73686581773ef9e73a7"}, - {file = "numpy-1.26.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bb894accfd16b867d8643fc2ba6c8617c78ba2828051e9a69511644ce86ce83e"}, - {file = "numpy-1.26.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e44ccb93f30c75dfc0c3aa3ce38f33486a75ec9abadabd4e59f114994a9c4617"}, - {file = "numpy-1.26.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9696aa2e35cc41e398a6d42d147cf326f8f9d81befcb399bc1ed7ffea339b64e"}, - {file = "numpy-1.26.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a5b411040beead47a228bde3b2241100454a6abde9df139ed087bd73fc0a4908"}, - {file = "numpy-1.26.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1e11668d6f756ca5ef534b5be8653d16c5352cbb210a5c2a79ff288e937010d5"}, - {file = "numpy-1.26.1-cp39-cp39-win32.whl", hash = 
"sha256:d1d2c6b7dd618c41e202c59c1413ef9b2c8e8a15f5039e344af64195459e3104"}, - {file = "numpy-1.26.1-cp39-cp39-win_amd64.whl", hash = "sha256:59227c981d43425ca5e5c01094d59eb14e8772ce6975d4b2fc1e106a833d5ae2"}, - {file = "numpy-1.26.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:06934e1a22c54636a059215d6da99e23286424f316fddd979f5071093b648668"}, - {file = "numpy-1.26.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76ff661a867d9272cd2a99eed002470f46dbe0943a5ffd140f49be84f68ffc42"}, - {file = "numpy-1.26.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:6965888d65d2848e8768824ca8288db0a81263c1efccec881cb35a0d805fcd2f"}, - {file = "numpy-1.26.1.tar.gz", hash = "sha256:c8c6c72d4a9f831f328efb1312642a1cafafaa88981d9ab76368d50d07d93cbe"}, +python-versions = ">=3.9" +files = [ + {file = "numpy-1.26.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3703fc9258a4a122d17043e57b35e5ef1c5a5837c3db8be396c82e04c1cf9b0f"}, + {file = "numpy-1.26.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cc392fdcbd21d4be6ae1bb4475a03ce3b025cd49a9be5345d76d7585aea69440"}, + {file = "numpy-1.26.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36340109af8da8805d8851ef1d74761b3b88e81a9bd80b290bbfed61bd2b4f75"}, + {file = "numpy-1.26.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcc008217145b3d77abd3e4d5ef586e3bdfba8fe17940769f8aa09b99e856c00"}, + {file = "numpy-1.26.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3ced40d4e9e18242f70dd02d739e44698df3dcb010d31f495ff00a31ef6014fe"}, + {file = "numpy-1.26.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b272d4cecc32c9e19911891446b72e986157e6a1809b7b56518b4f3755267523"}, + {file = "numpy-1.26.2-cp310-cp310-win32.whl", hash = "sha256:22f8fc02fdbc829e7a8c578dd8d2e15a9074b630d4da29cda483337e300e3ee9"}, + {file = "numpy-1.26.2-cp310-cp310-win_amd64.whl", hash = 
"sha256:26c9d33f8e8b846d5a65dd068c14e04018d05533b348d9eaeef6c1bd787f9919"}, + {file = "numpy-1.26.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b96e7b9c624ef3ae2ae0e04fa9b460f6b9f17ad8b4bec6d7756510f1f6c0c841"}, + {file = "numpy-1.26.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:aa18428111fb9a591d7a9cc1b48150097ba6a7e8299fb56bdf574df650e7d1f1"}, + {file = "numpy-1.26.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06fa1ed84aa60ea6ef9f91ba57b5ed963c3729534e6e54055fc151fad0423f0a"}, + {file = "numpy-1.26.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96ca5482c3dbdd051bcd1fce8034603d6ebfc125a7bd59f55b40d8f5d246832b"}, + {file = "numpy-1.26.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:854ab91a2906ef29dc3925a064fcd365c7b4da743f84b123002f6139bcb3f8a7"}, + {file = "numpy-1.26.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f43740ab089277d403aa07567be138fc2a89d4d9892d113b76153e0e412409f8"}, + {file = "numpy-1.26.2-cp311-cp311-win32.whl", hash = "sha256:a2bbc29fcb1771cd7b7425f98b05307776a6baf43035d3b80c4b0f29e9545186"}, + {file = "numpy-1.26.2-cp311-cp311-win_amd64.whl", hash = "sha256:2b3fca8a5b00184828d12b073af4d0fc5fdd94b1632c2477526f6bd7842d700d"}, + {file = "numpy-1.26.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a4cd6ed4a339c21f1d1b0fdf13426cb3b284555c27ac2f156dfdaaa7e16bfab0"}, + {file = "numpy-1.26.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5d5244aabd6ed7f312268b9247be47343a654ebea52a60f002dc70c769048e75"}, + {file = "numpy-1.26.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a3cdb4d9c70e6b8c0814239ead47da00934666f668426fc6e94cce869e13fd7"}, + {file = "numpy-1.26.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa317b2325f7aa0a9471663e6093c210cb2ae9c0ad824732b307d2c51983d5b6"}, + {file = "numpy-1.26.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = 
"sha256:174a8880739c16c925799c018f3f55b8130c1f7c8e75ab0a6fa9d41cab092fd6"}, + {file = "numpy-1.26.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f79b231bf5c16b1f39c7f4875e1ded36abee1591e98742b05d8a0fb55d8a3eec"}, + {file = "numpy-1.26.2-cp312-cp312-win32.whl", hash = "sha256:4a06263321dfd3598cacb252f51e521a8cb4b6df471bb12a7ee5cbab20ea9167"}, + {file = "numpy-1.26.2-cp312-cp312-win_amd64.whl", hash = "sha256:b04f5dc6b3efdaab541f7857351aac359e6ae3c126e2edb376929bd3b7f92d7e"}, + {file = "numpy-1.26.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4eb8df4bf8d3d90d091e0146f6c28492b0be84da3e409ebef54349f71ed271ef"}, + {file = "numpy-1.26.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1a13860fdcd95de7cf58bd6f8bc5a5ef81c0b0625eb2c9a783948847abbef2c2"}, + {file = "numpy-1.26.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64308ebc366a8ed63fd0bf426b6a9468060962f1a4339ab1074c228fa6ade8e3"}, + {file = "numpy-1.26.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baf8aab04a2c0e859da118f0b38617e5ee65d75b83795055fb66c0d5e9e9b818"}, + {file = "numpy-1.26.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d73a3abcac238250091b11caef9ad12413dab01669511779bc9b29261dd50210"}, + {file = "numpy-1.26.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:b361d369fc7e5e1714cf827b731ca32bff8d411212fccd29ad98ad622449cc36"}, + {file = "numpy-1.26.2-cp39-cp39-win32.whl", hash = "sha256:bd3f0091e845164a20bd5a326860c840fe2af79fa12e0469a12768a3ec578d80"}, + {file = "numpy-1.26.2-cp39-cp39-win_amd64.whl", hash = "sha256:2beef57fb031dcc0dc8fa4fe297a742027b954949cabb52a2a376c144e5e6060"}, + {file = "numpy-1.26.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:1cc3d5029a30fb5f06704ad6b23b35e11309491c999838c31f124fee32107c79"}, + {file = "numpy-1.26.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94cc3c222bb9fb5a12e334d0479b97bb2df446fbe622b470928f5284ffca3f8d"}, + {file = 
"numpy-1.26.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fe6b44fb8fcdf7eda4ef4461b97b3f63c466b27ab151bec2366db8b197387841"}, + {file = "numpy-1.26.2.tar.gz", hash = "sha256:f65738447676ab5777f11e6bbbdb8ce11b785e105f690bc45966574816b6d3ea"}, ] [[package]] @@ -2582,13 +2575,13 @@ files = [ [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.3.52" +version = "12.3.101" description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.3.52-py3-none-manylinux1_x86_64.whl", hash = "sha256:93db4dba8cb66fe2a351791e557208345bb9d0ace1bfb9dd05a4812f9a3ac74e"}, - {file = "nvidia_nvjitlink_cu12-12.3.52-py3-none-win_amd64.whl", hash = "sha256:9e403610da6ebceee897371a6982433ec997a9279d2320840413ce82a1d28ddc"}, + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"}, ] [[package]] @@ -2661,7 +2654,7 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.2" @@ -2733,34 +2726,24 @@ testing = ["docopt", "pytest (<6.0.0)"] [[package]] name = "pathspec" -version = "0.11.2" +version = "0.12.0" description = "Utility library for gitignore style pattern matching of file paths." 
optional = false -python-versions = ">=3.7" -files = [ - {file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"}, - {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, -] - -[[package]] -name = "pathtools" -version = "0.1.2" -description = "File system general utilities" -optional = false -python-versions = "*" +python-versions = ">=3.8" files = [ - {file = "pathtools-0.1.2.tar.gz", hash = "sha256:7c35c5421a39bb82e58018febd90e3b6e5db34c5443aaaf742b3f33d4655f1c0"}, + {file = "pathspec-0.12.0-py3-none-any.whl", hash = "sha256:f1f8a7eab698c357945c85ed79715e014612b8584faebe209dca4558e2b09513"}, + {file = "pathspec-0.12.0.tar.gz", hash = "sha256:c57e16065a97b7beb175f13c84d27cb05f7b7315741c2fbd5de541042f4ea6e1"}, ] [[package]] name = "pexpect" -version = "4.8.0" +version = "4.9.0" description = "Pexpect allows easy control of interactive console applications." optional = false python-versions = "*" files = [ - {file = "pexpect-4.8.0-py2.py3-none-any.whl", hash = "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937"}, - {file = "pexpect-4.8.0.tar.gz", hash = "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"}, + {file = "pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523"}, + {file = "pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f"}, ] [package.dependencies] @@ -2790,13 +2773,13 @@ files = [ [[package]] name = "platformdirs" -version = "3.11.0" +version = "4.1.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 
optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "platformdirs-3.11.0-py3-none-any.whl", hash = "sha256:e9d171d00af68be50e9202731309c4e658fd8bc76f55c11c7dd760d023bda68e"}, - {file = "platformdirs-3.11.0.tar.gz", hash = "sha256:cf8ee52a3afdb965072dcc652433e0c7e3e40cf5ea1477cd4b3b1d2eb75495b3"}, + {file = "platformdirs-4.1.0-py3-none-any.whl", hash = "sha256:11c8f37bcca40db96d8144522d925583bdb7a31f7b0e37e3ed4318400a8e2380"}, + {file = "platformdirs-4.1.0.tar.gz", hash = "sha256:906d548203468492d432bcb294d4bc2fff751bf84971fbb2c10918cc206ee420"}, ] [package.extras] @@ -2879,13 +2862,13 @@ six = ">=1.5.2" [[package]] name = "prometheus-client" -version = "0.17.1" +version = "0.19.0" description = "Python client for the Prometheus monitoring system." optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "prometheus_client-0.17.1-py3-none-any.whl", hash = "sha256:e537f37160f6807b8202a6fc4764cdd19bac5480ddd3e0d463c3002b34462101"}, - {file = "prometheus_client-0.17.1.tar.gz", hash = "sha256:21e674f39831ae3f8acde238afd9a27a37d0d2fb5a28ea094f0ce25d2cbf2091"}, + {file = "prometheus_client-0.19.0-py3-none-any.whl", hash = "sha256:c88b1e6ecf6b41cd8fb5731c7ae919bf66df6ec6fafa555cd6c0e16ca169ae92"}, + {file = "prometheus_client-0.19.0.tar.gz", hash = "sha256:4585b0d1223148c27a225b10dbec5ae9bc4c81a99a3fa80774fa6209935324e1"}, ] [package.extras] @@ -2893,13 +2876,13 @@ twisted = ["twisted"] [[package]] name = "prompt-toolkit" -version = "3.0.39" +version = "3.0.41" description = "Library for building powerful interactive command lines in Python" optional = false python-versions = ">=3.7.0" files = [ - {file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"}, - {file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"}, + {file = "prompt_toolkit-3.0.41-py3-none-any.whl", 
hash = "sha256:f36fe301fafb7470e86aaf90f036eef600a3210be4decf461a5b1ca8403d3cb2"}, + {file = "prompt_toolkit-3.0.41.tar.gz", hash = "sha256:941367d97fc815548822aa26c2a269fdc4eb21e9ec05fc5d447cf09bad5d75f0"}, ] [package.dependencies] @@ -2907,24 +2890,22 @@ wcwidth = "*" [[package]] name = "protobuf" -version = "4.24.4" +version = "4.25.1" description = "" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "protobuf-4.24.4-cp310-abi3-win32.whl", hash = "sha256:ec9912d5cb6714a5710e28e592ee1093d68c5ebfeda61983b3f40331da0b1ebb"}, - {file = "protobuf-4.24.4-cp310-abi3-win_amd64.whl", hash = "sha256:1badab72aa8a3a2b812eacfede5020472e16c6b2212d737cefd685884c191085"}, - {file = "protobuf-4.24.4-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:8e61a27f362369c2f33248a0ff6896c20dcd47b5d48239cb9720134bef6082e4"}, - {file = "protobuf-4.24.4-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:bffa46ad9612e6779d0e51ae586fde768339b791a50610d85eb162daeb23661e"}, - {file = "protobuf-4.24.4-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:b493cb590960ff863743b9ff1452c413c2ee12b782f48beca77c8da3e2ffe9d9"}, - {file = "protobuf-4.24.4-cp37-cp37m-win32.whl", hash = "sha256:dbbed8a56e56cee8d9d522ce844a1379a72a70f453bde6243e3c86c30c2a3d46"}, - {file = "protobuf-4.24.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6b7d2e1c753715dcfe9d284a25a52d67818dd43c4932574307daf836f0071e37"}, - {file = "protobuf-4.24.4-cp38-cp38-win32.whl", hash = "sha256:02212557a76cd99574775a81fefeba8738d0f668d6abd0c6b1d3adcc75503dbe"}, - {file = "protobuf-4.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:2fa3886dfaae6b4c5ed2730d3bf47c7a38a72b3a1f0acb4d4caf68e6874b947b"}, - {file = "protobuf-4.24.4-cp39-cp39-win32.whl", hash = "sha256:b77272f3e28bb416e2071186cb39efd4abbf696d682cbb5dc731308ad37fa6dd"}, - {file = "protobuf-4.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:9fee5e8aa20ef1b84123bb9232b3f4a5114d9897ed89b4b8142d81924e05d79b"}, - {file = 
"protobuf-4.24.4-py3-none-any.whl", hash = "sha256:80797ce7424f8c8d2f2547e2d42bfbb6c08230ce5832d6c099a37335c9c90a92"}, - {file = "protobuf-4.24.4.tar.gz", hash = "sha256:5a70731910cd9104762161719c3d883c960151eea077134458503723b60e3667"}, + {file = "protobuf-4.25.1-cp310-abi3-win32.whl", hash = "sha256:193f50a6ab78a970c9b4f148e7c750cfde64f59815e86f686c22e26b4fe01ce7"}, + {file = "protobuf-4.25.1-cp310-abi3-win_amd64.whl", hash = "sha256:3497c1af9f2526962f09329fd61a36566305e6c72da2590ae0d7d1322818843b"}, + {file = "protobuf-4.25.1-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:0bf384e75b92c42830c0a679b0cd4d6e2b36ae0cf3dbb1e1dfdda48a244f4bcd"}, + {file = "protobuf-4.25.1-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:0f881b589ff449bf0b931a711926e9ddaad3b35089cc039ce1af50b21a4ae8cb"}, + {file = "protobuf-4.25.1-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:ca37bf6a6d0046272c152eea90d2e4ef34593aaa32e8873fc14c16440f22d4b7"}, + {file = "protobuf-4.25.1-cp38-cp38-win32.whl", hash = "sha256:abc0525ae2689a8000837729eef7883b9391cd6aa7950249dcf5a4ede230d5dd"}, + {file = "protobuf-4.25.1-cp38-cp38-win_amd64.whl", hash = "sha256:1484f9e692091450e7edf418c939e15bfc8fc68856e36ce399aed6889dae8bb0"}, + {file = "protobuf-4.25.1-cp39-cp39-win32.whl", hash = "sha256:8bdbeaddaac52d15c6dce38c71b03038ef7772b977847eb6d374fc86636fa510"}, + {file = "protobuf-4.25.1-cp39-cp39-win_amd64.whl", hash = "sha256:becc576b7e6b553d22cbdf418686ee4daa443d7217999125c045ad56322dda10"}, + {file = "protobuf-4.25.1-py3-none-any.whl", hash = "sha256:a19731d5e83ae4737bb2a089605e636077ac001d18781b3cf489b9546c7c80d6"}, + {file = "protobuf-4.25.1.tar.gz", hash = "sha256:57d65074b4f5baa4ab5da1605c02be90ac20c8b40fb137d6a8df9f416b0d0ce2"}, ] [[package]] @@ -2982,58 +2963,76 @@ tests = ["pytest"] [[package]] name = "pyarrow" -version = "13.0.0" +version = "14.0.1" description = "Python library for Apache Arrow" optional = false python-versions = ">=3.8" files = [ - {file = 
"pyarrow-13.0.0-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:1afcc2c33f31f6fb25c92d50a86b7a9f076d38acbcb6f9e74349636109550148"}, - {file = "pyarrow-13.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:70fa38cdc66b2fc1349a082987f2b499d51d072faaa6b600f71931150de2e0e3"}, - {file = "pyarrow-13.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd57b13a6466822498238877892a9b287b0a58c2e81e4bdb0b596dbb151cbb73"}, - {file = "pyarrow-13.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8ce69f7bf01de2e2764e14df45b8404fc6f1a5ed9871e8e08a12169f87b7a26"}, - {file = "pyarrow-13.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:588f0d2da6cf1b1680974d63be09a6530fd1bd825dc87f76e162404779a157dc"}, - {file = "pyarrow-13.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:6241afd72b628787b4abea39e238e3ff9f34165273fad306c7acf780dd850956"}, - {file = "pyarrow-13.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:fda7857e35993673fcda603c07d43889fca60a5b254052a462653f8656c64f44"}, - {file = "pyarrow-13.0.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:aac0ae0146a9bfa5e12d87dda89d9ef7c57a96210b899459fc2f785303dcbb67"}, - {file = "pyarrow-13.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d7759994217c86c161c6a8060509cfdf782b952163569606bb373828afdd82e8"}, - {file = "pyarrow-13.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:868a073fd0ff6468ae7d869b5fc1f54de5c4255b37f44fb890385eb68b68f95d"}, - {file = "pyarrow-13.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51be67e29f3cfcde263a113c28e96aa04362ed8229cb7c6e5f5c719003659d33"}, - {file = "pyarrow-13.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:d1b4e7176443d12610874bb84d0060bf080f000ea9ed7c84b2801df851320295"}, - {file = "pyarrow-13.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:69b6f9a089d116a82c3ed819eea8fe67dae6105f0d81eaf0fdd5e60d0c6e0944"}, - 
{file = "pyarrow-13.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:ab1268db81aeb241200e321e220e7cd769762f386f92f61b898352dd27e402ce"}, - {file = "pyarrow-13.0.0-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:ee7490f0f3f16a6c38f8c680949551053c8194e68de5046e6c288e396dccee80"}, - {file = "pyarrow-13.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e3ad79455c197a36eefbd90ad4aa832bece7f830a64396c15c61a0985e337287"}, - {file = "pyarrow-13.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68fcd2dc1b7d9310b29a15949cdd0cb9bc34b6de767aff979ebf546020bf0ba0"}, - {file = "pyarrow-13.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc6fd330fd574c51d10638e63c0d00ab456498fc804c9d01f2a61b9264f2c5b2"}, - {file = "pyarrow-13.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:e66442e084979a97bb66939e18f7b8709e4ac5f887e636aba29486ffbf373763"}, - {file = "pyarrow-13.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:0f6eff839a9e40e9c5610d3ff8c5bdd2f10303408312caf4c8003285d0b49565"}, - {file = "pyarrow-13.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b30a27f1cddf5c6efcb67e598d7823a1e253d743d92ac32ec1eb4b6a1417867"}, - {file = "pyarrow-13.0.0-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:09552dad5cf3de2dc0aba1c7c4b470754c69bd821f5faafc3d774bedc3b04bb7"}, - {file = "pyarrow-13.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3896ae6c205d73ad192d2fc1489cd0edfab9f12867c85b4c277af4d37383c18c"}, - {file = "pyarrow-13.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6647444b21cb5e68b593b970b2a9a07748dd74ea457c7dadaa15fd469c48ada1"}, - {file = "pyarrow-13.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47663efc9c395e31d09c6aacfa860f4473815ad6804311c5433f7085415d62a7"}, - {file = "pyarrow-13.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:b9ba6b6d34bd2563345488cf444510588ea42ad5613df3b3509f48eb80250afd"}, - {file = 
"pyarrow-13.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:d00d374a5625beeb448a7fa23060df79adb596074beb3ddc1838adb647b6ef09"}, - {file = "pyarrow-13.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:c51afd87c35c8331b56f796eff954b9c7f8d4b7fef5903daf4e05fcf017d23a8"}, - {file = "pyarrow-13.0.0.tar.gz", hash = "sha256:83333726e83ed44b0ac94d8d7a21bbdee4a05029c3b1e8db58a863eec8fd8a33"}, + {file = "pyarrow-14.0.1-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:96d64e5ba7dceb519a955e5eeb5c9adcfd63f73a56aea4722e2cc81364fc567a"}, + {file = "pyarrow-14.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1a8ae88c0038d1bc362a682320112ee6774f006134cd5afc291591ee4bc06505"}, + {file = "pyarrow-14.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f6f053cb66dc24091f5511e5920e45c83107f954a21032feadc7b9e3a8e7851"}, + {file = "pyarrow-14.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:906b0dc25f2be12e95975722f1e60e162437023f490dbd80d0deb7375baf3171"}, + {file = "pyarrow-14.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:78d4a77a46a7de9388b653af1c4ce539350726cd9af62e0831e4f2bd0c95a2f4"}, + {file = "pyarrow-14.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:06ca79080ef89d6529bb8e5074d4b4f6086143b2520494fcb7cf8a99079cde93"}, + {file = "pyarrow-14.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:32542164d905002c42dff896efdac79b3bdd7291b1b74aa292fac8450d0e4dcd"}, + {file = "pyarrow-14.0.1-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:c7331b4ed3401b7ee56f22c980608cf273f0380f77d0f73dd3c185f78f5a6220"}, + {file = "pyarrow-14.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:922e8b49b88da8633d6cac0e1b5a690311b6758d6f5d7c2be71acb0f1e14cd61"}, + {file = "pyarrow-14.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58c889851ca33f992ea916b48b8540735055201b177cb0dcf0596a495a667b00"}, + {file = 
"pyarrow-14.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30d8494870d9916bb53b2a4384948491444741cb9a38253c590e21f836b01222"}, + {file = "pyarrow-14.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:be28e1a07f20391bb0b15ea03dcac3aade29fc773c5eb4bee2838e9b2cdde0cb"}, + {file = "pyarrow-14.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:981670b4ce0110d8dcb3246410a4aabf5714db5d8ea63b15686bce1c914b1f83"}, + {file = "pyarrow-14.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:4756a2b373a28f6166c42711240643fb8bd6322467e9aacabd26b488fa41ec23"}, + {file = "pyarrow-14.0.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:cf87e2cec65dd5cf1aa4aba918d523ef56ef95597b545bbaad01e6433851aa10"}, + {file = "pyarrow-14.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:470ae0194fbfdfbf4a6b65b4f9e0f6e1fa0ea5b90c1ee6b65b38aecee53508c8"}, + {file = "pyarrow-14.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6263cffd0c3721c1e348062997babdf0151301f7353010c9c9a8ed47448f82ab"}, + {file = "pyarrow-14.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a8089d7e77d1455d529dbd7cff08898bbb2666ee48bc4085203af1d826a33cc"}, + {file = "pyarrow-14.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:fada8396bc739d958d0b81d291cfd201126ed5e7913cb73de6bc606befc30226"}, + {file = "pyarrow-14.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:2a145dab9ed7849fc1101bf03bcdc69913547f10513fdf70fc3ab6c0a50c7eee"}, + {file = "pyarrow-14.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:05fe7994745b634c5fb16ce5717e39a1ac1fac3e2b0795232841660aa76647cd"}, + {file = "pyarrow-14.0.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:a8eeef015ae69d104c4c3117a6011e7e3ecd1abec79dc87fd2fac6e442f666ee"}, + {file = "pyarrow-14.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3c76807540989fe8fcd02285dd15e4f2a3da0b09d27781abec3adc265ddbeba1"}, + {file = 
"pyarrow-14.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:450e4605e3c20e558485f9161a79280a61c55efe585d51513c014de9ae8d393f"}, + {file = "pyarrow-14.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:323cbe60210173ffd7db78bfd50b80bdd792c4c9daca8843ef3cd70b186649db"}, + {file = "pyarrow-14.0.1-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0140c7e2b740e08c5a459439d87acd26b747fc408bde0a8806096ee0baaa0c15"}, + {file = "pyarrow-14.0.1-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:e592e482edd9f1ab32f18cd6a716c45b2c0f2403dc2af782f4e9674952e6dd27"}, + {file = "pyarrow-14.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:d264ad13605b61959f2ae7c1d25b1a5b8505b112715c961418c8396433f213ad"}, + {file = "pyarrow-14.0.1-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:01e44de9749cddc486169cb632f3c99962318e9dacac7778315a110f4bf8a450"}, + {file = "pyarrow-14.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d0351fecf0e26e152542bc164c22ea2a8e8c682726fce160ce4d459ea802d69c"}, + {file = "pyarrow-14.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33c1f6110c386464fd2e5e4ea3624466055bbe681ff185fd6c9daa98f30a3f9a"}, + {file = "pyarrow-14.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11e045dfa09855b6d3e7705a37c42e2dc2c71d608fab34d3c23df2e02df9aec3"}, + {file = "pyarrow-14.0.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:097828b55321897db0e1dbfc606e3ff8101ae5725673498cbfa7754ee0da80e4"}, + {file = "pyarrow-14.0.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:1daab52050a1c48506c029e6fa0944a7b2436334d7e44221c16f6f1b2cc9c510"}, + {file = "pyarrow-14.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:3f6d5faf4f1b0d5a7f97be987cf9e9f8cd39902611e818fe134588ee99bf0283"}, + {file = "pyarrow-14.0.1.tar.gz", hash = "sha256:b8b3f4fe8d4ec15e1ef9b599b94683c5216adaed78d5cb4c606180546d1e2ee1"}, ] [package.dependencies] numpy = ">=1.16.6" +[[package]] +name = 
"pyarrow-hotfix" +version = "0.6" +description = "" +optional = false +python-versions = ">=3.5" +files = [ + {file = "pyarrow_hotfix-0.6-py3-none-any.whl", hash = "sha256:dcc9ae2d220dff0083be6a9aa8e0cdee5182ad358d4931fce825c545e5c89178"}, + {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"}, +] + [[package]] name = "pycln" -version = "2.3.0" +version = "2.4.0" description = "A formatter for finding and removing unused import statements." optional = false -python-versions = ">=3.6.2,<4" +python-versions = ">=3.7.0,<4" files = [ - {file = "pycln-2.3.0-py3-none-any.whl", hash = "sha256:d6731e17a60728b827211de2ca4bfc9b40ea1df99a12f3e0fd06a98a0c9e6caa"}, - {file = "pycln-2.3.0.tar.gz", hash = "sha256:8759b36753234c8f95895a31dde329479ffed2218f49d1a1c77c7edccc02e09b"}, + {file = "pycln-2.4.0-py3-none-any.whl", hash = "sha256:d1bf648df17077306100815d255d45430035b36f66bac635df04a323c61ba126"}, + {file = "pycln-2.4.0.tar.gz", hash = "sha256:1f3eefb7be18a9ee06c3bdd0ba2e91218cd39317e20130325f107e96eb84b9f6"}, ] [package.dependencies] -libcst = {version = ">=0.3.10", markers = "python_version >= \"3.7\""} +libcst = ">=0.3.10" pathspec = ">=0.9.0" pyyaml = ">=5.3.1" tomlkit = ">=0.11.1" @@ -3052,17 +3051,18 @@ files = [ [[package]] name = "pygments" -version = "2.16.1" +version = "2.17.2" description = "Pygments is a syntax highlighting package written in Python." 
optional = false python-versions = ">=3.7" files = [ - {file = "Pygments-2.16.1-py3-none-any.whl", hash = "sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692"}, - {file = "Pygments-2.16.1.tar.gz", hash = "sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29"}, + {file = "pygments-2.17.2-py3-none-any.whl", hash = "sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c"}, + {file = "pygments-2.17.2.tar.gz", hash = "sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367"}, ] [package.extras] plugins = ["importlib-metadata"] +windows-terminal = ["colorama (>=0.4.6)"] [[package]] name = "pytest" @@ -3209,6 +3209,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -3216,8 +3217,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = 
"PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -3234,6 +3242,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -3241,6 +3250,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = 
"sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -3248,104 +3258,104 @@ files = [ [[package]] name = "pyzmq" -version = "25.1.1" +version = "25.1.2" description = "Python bindings for 0MQ" optional = false python-versions = ">=3.6" files = [ - {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:381469297409c5adf9a0e884c5eb5186ed33137badcbbb0560b86e910a2f1e76"}, - {file = "pyzmq-25.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:955215ed0604dac5b01907424dfa28b40f2b2292d6493445dd34d0dfa72586a8"}, - {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:985bbb1316192b98f32e25e7b9958088431d853ac63aca1d2c236f40afb17c83"}, - {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:afea96f64efa98df4da6958bae37f1cbea7932c35878b185e5982821bc883369"}, - {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76705c9325d72a81155bb6ab48d4312e0032bf045fb0754889133200f7a0d849"}, - {file = "pyzmq-25.1.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:77a41c26205d2353a4c94d02be51d6cbdf63c06fbc1295ea57dad7e2d3381b71"}, - {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:12720a53e61c3b99d87262294e2b375c915fea93c31fc2336898c26d7aed34cd"}, - {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:57459b68e5cd85b0be8184382cefd91959cafe79ae019e6b1ae6e2ba8a12cda7"}, - {file = "pyzmq-25.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:292fe3fc5ad4a75bc8df0dfaee7d0babe8b1f4ceb596437213821f761b4589f9"}, - {file = "pyzmq-25.1.1-cp310-cp310-win32.whl", hash = "sha256:35b5ab8c28978fbbb86ea54958cd89f5176ce747c1fb3d87356cf698048a7790"}, - {file = "pyzmq-25.1.1-cp310-cp310-win_amd64.whl", hash = 
"sha256:11baebdd5fc5b475d484195e49bae2dc64b94a5208f7c89954e9e354fc609d8f"}, - {file = "pyzmq-25.1.1-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:d20a0ddb3e989e8807d83225a27e5c2eb2260eaa851532086e9e0fa0d5287d83"}, - {file = "pyzmq-25.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e1c1be77bc5fb77d923850f82e55a928f8638f64a61f00ff18a67c7404faf008"}, - {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d89528b4943d27029a2818f847c10c2cecc79fa9590f3cb1860459a5be7933eb"}, - {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:90f26dc6d5f241ba358bef79be9ce06de58d477ca8485e3291675436d3827cf8"}, - {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2b92812bd214018e50b6380ea3ac0c8bb01ac07fcc14c5f86a5bb25e74026e9"}, - {file = "pyzmq-25.1.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:2f957ce63d13c28730f7fd6b72333814221c84ca2421298f66e5143f81c9f91f"}, - {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:047a640f5c9c6ade7b1cc6680a0e28c9dd5a0825135acbd3569cc96ea00b2505"}, - {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7f7e58effd14b641c5e4dec8c7dab02fb67a13df90329e61c869b9cc607ef752"}, - {file = "pyzmq-25.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c2910967e6ab16bf6fbeb1f771c89a7050947221ae12a5b0b60f3bca2ee19bca"}, - {file = "pyzmq-25.1.1-cp311-cp311-win32.whl", hash = "sha256:76c1c8efb3ca3a1818b837aea423ff8a07bbf7aafe9f2f6582b61a0458b1a329"}, - {file = "pyzmq-25.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:44e58a0554b21fc662f2712814a746635ed668d0fbc98b7cb9d74cb798d202e6"}, - {file = "pyzmq-25.1.1-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:e1ffa1c924e8c72778b9ccd386a7067cddf626884fd8277f503c48bb5f51c762"}, - {file = "pyzmq-25.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:1af379b33ef33757224da93e9da62e6471cf4a66d10078cf32bae8127d3d0d4a"}, - {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cff084c6933680d1f8b2f3b4ff5bbb88538a4aac00d199ac13f49d0698727ecb"}, - {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2400a94f7dd9cb20cd012951a0cbf8249e3d554c63a9c0cdfd5cbb6c01d2dec"}, - {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d81f1ddae3858b8299d1da72dd7d19dd36aab654c19671aa8a7e7fb02f6638a"}, - {file = "pyzmq-25.1.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:255ca2b219f9e5a3a9ef3081512e1358bd4760ce77828e1028b818ff5610b87b"}, - {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a882ac0a351288dd18ecae3326b8a49d10c61a68b01419f3a0b9a306190baf69"}, - {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:724c292bb26365659fc434e9567b3f1adbdb5e8d640c936ed901f49e03e5d32e"}, - {file = "pyzmq-25.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ca1ed0bb2d850aa8471387882247c68f1e62a4af0ce9c8a1dbe0d2bf69e41fb"}, - {file = "pyzmq-25.1.1-cp312-cp312-win32.whl", hash = "sha256:b3451108ab861040754fa5208bca4a5496c65875710f76789a9ad27c801a0075"}, - {file = "pyzmq-25.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:eadbefd5e92ef8a345f0525b5cfd01cf4e4cc651a2cffb8f23c0dd184975d787"}, - {file = "pyzmq-25.1.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:db0b2af416ba735c6304c47f75d348f498b92952f5e3e8bff449336d2728795d"}, - {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c7c133e93b405eb0d36fa430c94185bdd13c36204a8635470cccc200723c13bb"}, - {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:273bc3959bcbff3f48606b28229b4721716598d76b5aaea2b4a9d0ab454ec062"}, - {file = "pyzmq-25.1.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = 
"sha256:cbc8df5c6a88ba5ae385d8930da02201165408dde8d8322072e3e5ddd4f68e22"}, - {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:18d43df3f2302d836f2a56f17e5663e398416e9dd74b205b179065e61f1a6edf"}, - {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:73461eed88a88c866656e08f89299720a38cb4e9d34ae6bf5df6f71102570f2e"}, - {file = "pyzmq-25.1.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:34c850ce7976d19ebe7b9d4b9bb8c9dfc7aac336c0958e2651b88cbd46682123"}, - {file = "pyzmq-25.1.1-cp36-cp36m-win32.whl", hash = "sha256:d2045d6d9439a0078f2a34b57c7b18c4a6aef0bee37f22e4ec9f32456c852c71"}, - {file = "pyzmq-25.1.1-cp36-cp36m-win_amd64.whl", hash = "sha256:458dea649f2f02a0b244ae6aef8dc29325a2810aa26b07af8374dc2a9faf57e3"}, - {file = "pyzmq-25.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7cff25c5b315e63b07a36f0c2bab32c58eafbe57d0dce61b614ef4c76058c115"}, - {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1579413ae492b05de5a6174574f8c44c2b9b122a42015c5292afa4be2507f28"}, - {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3d0a409d3b28607cc427aa5c30a6f1e4452cc44e311f843e05edb28ab5e36da0"}, - {file = "pyzmq-25.1.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:21eb4e609a154a57c520e3d5bfa0d97e49b6872ea057b7c85257b11e78068222"}, - {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:034239843541ef7a1aee0c7b2cb7f6aafffb005ede965ae9cbd49d5ff4ff73cf"}, - {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f8115e303280ba09f3898194791a153862cbf9eef722ad8f7f741987ee2a97c7"}, - {file = "pyzmq-25.1.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:1a5d26fe8f32f137e784f768143728438877d69a586ddeaad898558dc971a5ae"}, - {file = "pyzmq-25.1.1-cp37-cp37m-win32.whl", hash = "sha256:f32260e556a983bc5c7ed588d04c942c9a8f9c2e99213fec11a031e316874c7e"}, - {file = 
"pyzmq-25.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:abf34e43c531bbb510ae7e8f5b2b1f2a8ab93219510e2b287a944432fad135f3"}, - {file = "pyzmq-25.1.1-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:87e34f31ca8f168c56d6fbf99692cc8d3b445abb5bfd08c229ae992d7547a92a"}, - {file = "pyzmq-25.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c9c6c9b2c2f80747a98f34ef491c4d7b1a8d4853937bb1492774992a120f475d"}, - {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5619f3f5a4db5dbb572b095ea3cb5cc035335159d9da950830c9c4db2fbb6995"}, - {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5a34d2395073ef862b4032343cf0c32a712f3ab49d7ec4f42c9661e0294d106f"}, - {file = "pyzmq-25.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25f0e6b78220aba09815cd1f3a32b9c7cb3e02cb846d1cfc526b6595f6046618"}, - {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:3669cf8ee3520c2f13b2e0351c41fea919852b220988d2049249db10046a7afb"}, - {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2d163a18819277e49911f7461567bda923461c50b19d169a062536fffe7cd9d2"}, - {file = "pyzmq-25.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:df27ffddff4190667d40de7beba4a950b5ce78fe28a7dcc41d6f8a700a80a3c0"}, - {file = "pyzmq-25.1.1-cp38-cp38-win32.whl", hash = "sha256:a382372898a07479bd34bda781008e4a954ed8750f17891e794521c3e21c2e1c"}, - {file = "pyzmq-25.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:52533489f28d62eb1258a965f2aba28a82aa747202c8fa5a1c7a43b5db0e85c1"}, - {file = "pyzmq-25.1.1-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:03b3f49b57264909aacd0741892f2aecf2f51fb053e7d8ac6767f6c700832f45"}, - {file = "pyzmq-25.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:330f9e188d0d89080cde66dc7470f57d1926ff2fb5576227f14d5be7ab30b9fa"}, - {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = 
"sha256:2ca57a5be0389f2a65e6d3bb2962a971688cbdd30b4c0bd188c99e39c234f414"}, - {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d457aed310f2670f59cc5b57dcfced452aeeed77f9da2b9763616bd57e4dbaae"}, - {file = "pyzmq-25.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c56d748ea50215abef7030c72b60dd723ed5b5c7e65e7bc2504e77843631c1a6"}, - {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8f03d3f0d01cb5a018debeb412441996a517b11c5c17ab2001aa0597c6d6882c"}, - {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:820c4a08195a681252f46926de10e29b6bbf3e17b30037bd4250d72dd3ddaab8"}, - {file = "pyzmq-25.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:17ef5f01d25b67ca8f98120d5fa1d21efe9611604e8eb03a5147360f517dd1e2"}, - {file = "pyzmq-25.1.1-cp39-cp39-win32.whl", hash = "sha256:04ccbed567171579ec2cebb9c8a3e30801723c575601f9a990ab25bcac6b51e2"}, - {file = "pyzmq-25.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:e61f091c3ba0c3578411ef505992d356a812fb200643eab27f4f70eed34a29ef"}, - {file = "pyzmq-25.1.1-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ade6d25bb29c4555d718ac6d1443a7386595528c33d6b133b258f65f963bb0f6"}, - {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0c95ddd4f6e9fca4e9e3afaa4f9df8552f0ba5d1004e89ef0a68e1f1f9807c7"}, - {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48e466162a24daf86f6b5ca72444d2bf39a5e58da5f96370078be67c67adc978"}, - {file = "pyzmq-25.1.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abc719161780932c4e11aaebb203be3d6acc6b38d2f26c0f523b5b59d2fc1996"}, - {file = "pyzmq-25.1.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1ccf825981640b8c34ae54231b7ed00271822ea1c6d8ba1090ebd4943759abf5"}, - {file = "pyzmq-25.1.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = 
"sha256:c2f20ce161ebdb0091a10c9ca0372e023ce24980d0e1f810f519da6f79c60800"}, - {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:deee9ca4727f53464daf089536e68b13e6104e84a37820a88b0a057b97bba2d2"}, - {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aa8d6cdc8b8aa19ceb319aaa2b660cdaccc533ec477eeb1309e2a291eaacc43a"}, - {file = "pyzmq-25.1.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:019e59ef5c5256a2c7378f2fb8560fc2a9ff1d315755204295b2eab96b254d0a"}, - {file = "pyzmq-25.1.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:b9af3757495c1ee3b5c4e945c1df7be95562277c6e5bccc20a39aec50f826cd0"}, - {file = "pyzmq-25.1.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:548d6482dc8aadbe7e79d1b5806585c8120bafa1ef841167bc9090522b610fa6"}, - {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:057e824b2aae50accc0f9a0570998adc021b372478a921506fddd6c02e60308e"}, - {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2243700cc5548cff20963f0ca92d3e5e436394375ab8a354bbea2b12911b20b0"}, - {file = "pyzmq-25.1.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79986f3b4af059777111409ee517da24a529bdbd46da578b33f25580adcff728"}, - {file = "pyzmq-25.1.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:11d58723d44d6ed4dd677c5615b2ffb19d5c426636345567d6af82be4dff8a55"}, - {file = "pyzmq-25.1.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:49d238cf4b69652257db66d0c623cd3e09b5d2e9576b56bc067a396133a00d4a"}, - {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fedbdc753827cf014c01dbbee9c3be17e5a208dcd1bf8641ce2cd29580d1f0d4"}, - {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:bc16ac425cc927d0a57d242589f87ee093884ea4804c05a13834d07c20db203c"}, - {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11c1d2aed9079c6b0c9550a7257a836b4a637feb334904610f06d70eb44c56d2"}, - {file = "pyzmq-25.1.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e8a701123029cc240cea61dd2d16ad57cab4691804143ce80ecd9286b464d180"}, - {file = "pyzmq-25.1.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:61706a6b6c24bdece85ff177fec393545a3191eeda35b07aaa1458a027ad1304"}, - {file = "pyzmq-25.1.1.tar.gz", hash = "sha256:259c22485b71abacdfa8bf79720cd7bcf4b9d128b30ea554f01ae71fdbfdaa23"}, + {file = "pyzmq-25.1.2-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:e624c789359f1a16f83f35e2c705d07663ff2b4d4479bad35621178d8f0f6ea4"}, + {file = "pyzmq-25.1.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:49151b0efece79f6a79d41a461d78535356136ee70084a1c22532fc6383f4ad0"}, + {file = "pyzmq-25.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9a5f194cf730f2b24d6af1f833c14c10f41023da46a7f736f48b6d35061e76e"}, + {file = "pyzmq-25.1.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:faf79a302f834d9e8304fafdc11d0d042266667ac45209afa57e5efc998e3872"}, + {file = "pyzmq-25.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f51a7b4ead28d3fca8dda53216314a553b0f7a91ee8fc46a72b402a78c3e43d"}, + {file = "pyzmq-25.1.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:0ddd6d71d4ef17ba5a87becf7ddf01b371eaba553c603477679ae817a8d84d75"}, + {file = "pyzmq-25.1.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:246747b88917e4867e2367b005fc8eefbb4a54b7db363d6c92f89d69abfff4b6"}, + {file = "pyzmq-25.1.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:00c48ae2fd81e2a50c3485de1b9d5c7c57cd85dc8ec55683eac16846e57ac979"}, + {file = "pyzmq-25.1.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = 
"sha256:5a68d491fc20762b630e5db2191dd07ff89834086740f70e978bb2ef2668be08"}, + {file = "pyzmq-25.1.2-cp310-cp310-win32.whl", hash = "sha256:09dfe949e83087da88c4a76767df04b22304a682d6154de2c572625c62ad6886"}, + {file = "pyzmq-25.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:fa99973d2ed20417744fca0073390ad65ce225b546febb0580358e36aa90dba6"}, + {file = "pyzmq-25.1.2-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:82544e0e2d0c1811482d37eef297020a040c32e0687c1f6fc23a75b75db8062c"}, + {file = "pyzmq-25.1.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:01171fc48542348cd1a360a4b6c3e7d8f46cdcf53a8d40f84db6707a6768acc1"}, + {file = "pyzmq-25.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc69c96735ab501419c432110016329bf0dea8898ce16fab97c6d9106dc0b348"}, + {file = "pyzmq-25.1.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3e124e6b1dd3dfbeb695435dff0e383256655bb18082e094a8dd1f6293114642"}, + {file = "pyzmq-25.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7598d2ba821caa37a0f9d54c25164a4fa351ce019d64d0b44b45540950458840"}, + {file = "pyzmq-25.1.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:d1299d7e964c13607efd148ca1f07dcbf27c3ab9e125d1d0ae1d580a1682399d"}, + {file = "pyzmq-25.1.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4e6f689880d5ad87918430957297c975203a082d9a036cc426648fcbedae769b"}, + {file = "pyzmq-25.1.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cc69949484171cc961e6ecd4a8911b9ce7a0d1f738fcae717177c231bf77437b"}, + {file = "pyzmq-25.1.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9880078f683466b7f567b8624bfc16cad65077be046b6e8abb53bed4eeb82dd3"}, + {file = "pyzmq-25.1.2-cp311-cp311-win32.whl", hash = "sha256:4e5837af3e5aaa99a091302df5ee001149baff06ad22b722d34e30df5f0d9097"}, + {file = "pyzmq-25.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:25c2dbb97d38b5ac9fd15586e048ec5eb1e38f3d47fe7d92167b0c77bb3584e9"}, + {file = 
"pyzmq-25.1.2-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:11e70516688190e9c2db14fcf93c04192b02d457b582a1f6190b154691b4c93a"}, + {file = "pyzmq-25.1.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:313c3794d650d1fccaaab2df942af9f2c01d6217c846177cfcbc693c7410839e"}, + {file = "pyzmq-25.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b3cbba2f47062b85fe0ef9de5b987612140a9ba3a9c6d2543c6dec9f7c2ab27"}, + {file = "pyzmq-25.1.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fc31baa0c32a2ca660784d5af3b9487e13b61b3032cb01a115fce6588e1bed30"}, + {file = "pyzmq-25.1.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02c9087b109070c5ab0b383079fa1b5f797f8d43e9a66c07a4b8b8bdecfd88ee"}, + {file = "pyzmq-25.1.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:f8429b17cbb746c3e043cb986328da023657e79d5ed258b711c06a70c2ea7537"}, + {file = "pyzmq-25.1.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:5074adeacede5f810b7ef39607ee59d94e948b4fd954495bdb072f8c54558181"}, + {file = "pyzmq-25.1.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7ae8f354b895cbd85212da245f1a5ad8159e7840e37d78b476bb4f4c3f32a9fe"}, + {file = "pyzmq-25.1.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b264bf2cc96b5bc43ce0e852be995e400376bd87ceb363822e2cb1964fcdc737"}, + {file = "pyzmq-25.1.2-cp312-cp312-win32.whl", hash = "sha256:02bbc1a87b76e04fd780b45e7f695471ae6de747769e540da909173d50ff8e2d"}, + {file = "pyzmq-25.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:ced111c2e81506abd1dc142e6cd7b68dd53747b3b7ae5edbea4578c5eeff96b7"}, + {file = "pyzmq-25.1.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:7b6d09a8962a91151f0976008eb7b29b433a560fde056ec7a3db9ec8f1075438"}, + {file = "pyzmq-25.1.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:967668420f36878a3c9ecb5ab33c9d0ff8d054f9c0233d995a6d25b0e95e1b6b"}, + {file = 
"pyzmq-25.1.2-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5edac3f57c7ddaacdb4d40f6ef2f9e299471fc38d112f4bc6d60ab9365445fb0"}, + {file = "pyzmq-25.1.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:0dabfb10ef897f3b7e101cacba1437bd3a5032ee667b7ead32bbcdd1a8422fe7"}, + {file = "pyzmq-25.1.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:2c6441e0398c2baacfe5ba30c937d274cfc2dc5b55e82e3749e333aabffde561"}, + {file = "pyzmq-25.1.2-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:16b726c1f6c2e7625706549f9dbe9b06004dfbec30dbed4bf50cbdfc73e5b32a"}, + {file = "pyzmq-25.1.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:a86c2dd76ef71a773e70551a07318b8e52379f58dafa7ae1e0a4be78efd1ff16"}, + {file = "pyzmq-25.1.2-cp36-cp36m-win32.whl", hash = "sha256:359f7f74b5d3c65dae137f33eb2bcfa7ad9ebefd1cab85c935f063f1dbb245cc"}, + {file = "pyzmq-25.1.2-cp36-cp36m-win_amd64.whl", hash = "sha256:55875492f820d0eb3417b51d96fea549cde77893ae3790fd25491c5754ea2f68"}, + {file = "pyzmq-25.1.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b8c8a419dfb02e91b453615c69568442e897aaf77561ee0064d789705ff37a92"}, + {file = "pyzmq-25.1.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8807c87fa893527ae8a524c15fc505d9950d5e856f03dae5921b5e9aa3b8783b"}, + {file = "pyzmq-25.1.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5e319ed7d6b8f5fad9b76daa0a68497bc6f129858ad956331a5835785761e003"}, + {file = "pyzmq-25.1.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:3c53687dde4d9d473c587ae80cc328e5b102b517447456184b485587ebd18b62"}, + {file = "pyzmq-25.1.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:9add2e5b33d2cd765ad96d5eb734a5e795a0755f7fc49aa04f76d7ddda73fd70"}, + {file = "pyzmq-25.1.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:e690145a8c0c273c28d3b89d6fb32c45e0d9605b2293c10e650265bf5c11cfec"}, + {file = "pyzmq-25.1.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash 
= "sha256:00a06faa7165634f0cac1abb27e54d7a0b3b44eb9994530b8ec73cf52e15353b"}, + {file = "pyzmq-25.1.2-cp37-cp37m-win32.whl", hash = "sha256:0f97bc2f1f13cb16905a5f3e1fbdf100e712d841482b2237484360f8bc4cb3d7"}, + {file = "pyzmq-25.1.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6cc0020b74b2e410287e5942e1e10886ff81ac77789eb20bec13f7ae681f0fdd"}, + {file = "pyzmq-25.1.2-cp38-cp38-macosx_10_15_universal2.whl", hash = "sha256:bef02cfcbded83473bdd86dd8d3729cd82b2e569b75844fb4ea08fee3c26ae41"}, + {file = "pyzmq-25.1.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e10a4b5a4b1192d74853cc71a5e9fd022594573926c2a3a4802020360aa719d8"}, + {file = "pyzmq-25.1.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8c5f80e578427d4695adac6fdf4370c14a2feafdc8cb35549c219b90652536ae"}, + {file = "pyzmq-25.1.2-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:5dde6751e857910c1339890f3524de74007958557593b9e7e8c5f01cd919f8a7"}, + {file = "pyzmq-25.1.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea1608dd169da230a0ad602d5b1ebd39807ac96cae1845c3ceed39af08a5c6df"}, + {file = "pyzmq-25.1.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:0f513130c4c361201da9bc69df25a086487250e16b5571ead521b31ff6b02220"}, + {file = "pyzmq-25.1.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:019744b99da30330798bb37df33549d59d380c78e516e3bab9c9b84f87a9592f"}, + {file = "pyzmq-25.1.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2e2713ef44be5d52dd8b8e2023d706bf66cb22072e97fc71b168e01d25192755"}, + {file = "pyzmq-25.1.2-cp38-cp38-win32.whl", hash = "sha256:07cd61a20a535524906595e09344505a9bd46f1da7a07e504b315d41cd42eb07"}, + {file = "pyzmq-25.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:eb7e49a17fb8c77d3119d41a4523e432eb0c6932187c37deb6fbb00cc3028088"}, + {file = "pyzmq-25.1.2-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:94504ff66f278ab4b7e03e4cba7e7e400cb73bfa9d3d71f58d8972a8dc67e7a6"}, + {file = 
"pyzmq-25.1.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6dd0d50bbf9dca1d0bdea219ae6b40f713a3fb477c06ca3714f208fd69e16fd8"}, + {file = "pyzmq-25.1.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:004ff469d21e86f0ef0369717351073e0e577428e514c47c8480770d5e24a565"}, + {file = "pyzmq-25.1.2-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c0b5ca88a8928147b7b1e2dfa09f3b6c256bc1135a1338536cbc9ea13d3b7add"}, + {file = "pyzmq-25.1.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c9a79f1d2495b167119d02be7448bfba57fad2a4207c4f68abc0bab4b92925b"}, + {file = "pyzmq-25.1.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:518efd91c3d8ac9f9b4f7dd0e2b7b8bf1a4fe82a308009016b07eaa48681af82"}, + {file = "pyzmq-25.1.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:1ec23bd7b3a893ae676d0e54ad47d18064e6c5ae1fadc2f195143fb27373f7f6"}, + {file = "pyzmq-25.1.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db36c27baed588a5a8346b971477b718fdc66cf5b80cbfbd914b4d6d355e44e2"}, + {file = "pyzmq-25.1.2-cp39-cp39-win32.whl", hash = "sha256:39b1067f13aba39d794a24761e385e2eddc26295826530a8c7b6c6c341584289"}, + {file = "pyzmq-25.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:8e9f3fabc445d0ce320ea2c59a75fe3ea591fdbdeebec5db6de530dd4b09412e"}, + {file = "pyzmq-25.1.2-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a8c1d566344aee826b74e472e16edae0a02e2a044f14f7c24e123002dcff1c05"}, + {file = "pyzmq-25.1.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:759cfd391a0996345ba94b6a5110fca9c557ad4166d86a6e81ea526c376a01e8"}, + {file = "pyzmq-25.1.2-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c61e346ac34b74028ede1c6b4bcecf649d69b707b3ff9dc0fab453821b04d1e"}, + {file = "pyzmq-25.1.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4cb8fc1f8d69b411b8ec0b5f1ffbcaf14c1db95b6bccea21d83610987435f1a4"}, + 
{file = "pyzmq-25.1.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:3c00c9b7d1ca8165c610437ca0c92e7b5607b2f9076f4eb4b095c85d6e680a1d"}, + {file = "pyzmq-25.1.2-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:df0c7a16ebb94452d2909b9a7b3337940e9a87a824c4fc1c7c36bb4404cb0cde"}, + {file = "pyzmq-25.1.2-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:45999e7f7ed5c390f2e87ece7f6c56bf979fb213550229e711e45ecc7d42ccb8"}, + {file = "pyzmq-25.1.2-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ac170e9e048b40c605358667aca3d94e98f604a18c44bdb4c102e67070f3ac9b"}, + {file = "pyzmq-25.1.2-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1b604734bec94f05f81b360a272fc824334267426ae9905ff32dc2be433ab96"}, + {file = "pyzmq-25.1.2-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:a793ac733e3d895d96f865f1806f160696422554e46d30105807fdc9841b9f7d"}, + {file = "pyzmq-25.1.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0806175f2ae5ad4b835ecd87f5f85583316b69f17e97786f7443baaf54b9bb98"}, + {file = "pyzmq-25.1.2-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ef12e259e7bc317c7597d4f6ef59b97b913e162d83b421dd0db3d6410f17a244"}, + {file = "pyzmq-25.1.2-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ea253b368eb41116011add00f8d5726762320b1bda892f744c91997b65754d73"}, + {file = "pyzmq-25.1.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b9b1f2ad6498445a941d9a4fee096d387fee436e45cc660e72e768d3d8ee611"}, + {file = "pyzmq-25.1.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:8b14c75979ce932c53b79976a395cb2a8cd3aaf14aef75e8c2cb55a330b9b49d"}, + {file = "pyzmq-25.1.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:889370d5174a741a62566c003ee8ddba4b04c3f09a97b8000092b7ca83ec9c49"}, + {file = "pyzmq-25.1.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:9a18fff090441a40ffda8a7f4f18f03dc56ae73f148f1832e109f9bffa85df15"}, + {file = "pyzmq-25.1.2-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99a6b36f95c98839ad98f8c553d8507644c880cf1e0a57fe5e3a3f3969040882"}, + {file = "pyzmq-25.1.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4345c9a27f4310afbb9c01750e9461ff33d6fb74cd2456b107525bbeebcb5be3"}, + {file = "pyzmq-25.1.2-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:3516e0b6224cf6e43e341d56da15fd33bdc37fa0c06af4f029f7d7dfceceabbc"}, + {file = "pyzmq-25.1.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:146b9b1f29ead41255387fb07be56dc29639262c0f7344f570eecdcd8d683314"}, + {file = "pyzmq-25.1.2.tar.gz", hash = "sha256:93f1aa311e8bb912e34f004cf186407a4e90eec4f0ecc0efd26056bf7eda0226"}, ] [package.dependencies] @@ -3353,18 +3363,17 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} [[package]] name = "qtconsole" -version = "5.4.4" +version = "5.5.1" description = "Jupyter Qt console" optional = false -python-versions = ">= 3.7" +python-versions = ">= 3.8" files = [ - {file = "qtconsole-5.4.4-py3-none-any.whl", hash = "sha256:a3b69b868e041c2c698bdc75b0602f42e130ffb256d6efa48f9aa756c97672aa"}, - {file = "qtconsole-5.4.4.tar.gz", hash = "sha256:b7ffb53d74f23cee29f4cdb55dd6fabc8ec312d94f3c46ba38e1dde458693dfb"}, + {file = "qtconsole-5.5.1-py3-none-any.whl", hash = "sha256:8c75fa3e9b4ed884880ff7cea90a1b67451219279ec33deaee1d59e3df1a5d2b"}, + {file = "qtconsole-5.5.1.tar.gz", hash = "sha256:a0e806c6951db9490628e4df80caec9669b65149c7ba40f9bf033c025a5b56bc"}, ] [package.dependencies] ipykernel = ">=4.1" -ipython-genutils = "*" jupyter-client = ">=4.1" jupyter-core = "*" packaging = "*" @@ -3396,13 +3405,13 @@ test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] [[package]] name = "referencing" -version = "0.30.2" +version = "0.32.0" description = "JSON Referencing + Python" optional = false 
python-versions = ">=3.8" files = [ - {file = "referencing-0.30.2-py3-none-any.whl", hash = "sha256:449b6669b6121a9e96a7f9e410b245d471e8d48964c67113ce9afe50c8dd7bdf"}, - {file = "referencing-0.30.2.tar.gz", hash = "sha256:794ad8003c65938edcdbc027f1933215e0d0ccc0291e3ce20a4d87432b59efc0"}, + {file = "referencing-0.32.0-py3-none-any.whl", hash = "sha256:bdcd3efb936f82ff86f993093f6da7435c7de69a3b3a5a06678a6050184bee99"}, + {file = "referencing-0.32.0.tar.gz", hash = "sha256:689e64fe121843dcfd57b71933318ef1f91188ffb45367332700a86ac8fd6161"}, ] [package.dependencies] @@ -3554,13 +3563,13 @@ files = [ [[package]] name = "rich" -version = "13.6.0" +version = "13.7.0" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false python-versions = ">=3.7.0" files = [ - {file = "rich-13.6.0-py3-none-any.whl", hash = "sha256:2b38e2fe9ca72c9a00170a1a2d20c63c790d0e10ef1fe35eba76e1e7b1d7d245"}, - {file = "rich-13.6.0.tar.gz", hash = "sha256:5c14d22737e6d5084ef4771b62d5d4363165b403455a30a1c8ca39dc7b644bef"}, + {file = "rich-13.7.0-py3-none-any.whl", hash = "sha256:6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235"}, + {file = "rich-13.7.0.tar.gz", hash = "sha256:5cb5123b5cf9ee70584244246816e9114227e0b98ad9176eede6ad54bf5403fa"}, ] [package.dependencies] @@ -3573,217 +3582,217 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] [[package]] name = "rpds-py" -version = "0.10.6" +version = "0.13.2" description = "Python bindings to Rust's persistent data structures (rpds)" optional = false python-versions = ">=3.8" files = [ - {file = "rpds_py-0.10.6-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:6bdc11f9623870d75692cc33c59804b5a18d7b8a4b79ef0b00b773a27397d1f6"}, - {file = "rpds_py-0.10.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:26857f0f44f0e791f4a266595a7a09d21f6b589580ee0585f330aaccccb836e3"}, - {file = "rpds_py-0.10.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:d7f5e15c953ace2e8dde9824bdab4bec50adb91a5663df08d7d994240ae6fa31"}, - {file = "rpds_py-0.10.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:61fa268da6e2e1cd350739bb61011121fa550aa2545762e3dc02ea177ee4de35"}, - {file = "rpds_py-0.10.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c48f3fbc3e92c7dd6681a258d22f23adc2eb183c8cb1557d2fcc5a024e80b094"}, - {file = "rpds_py-0.10.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c0503c5b681566e8b722fe8c4c47cce5c7a51f6935d5c7012c4aefe952a35eed"}, - {file = "rpds_py-0.10.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:734c41f9f57cc28658d98270d3436dba65bed0cfc730d115b290e970150c540d"}, - {file = "rpds_py-0.10.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a5d7ed104d158c0042a6a73799cf0eb576dfd5fc1ace9c47996e52320c37cb7c"}, - {file = "rpds_py-0.10.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e3df0bc35e746cce42579826b89579d13fd27c3d5319a6afca9893a9b784ff1b"}, - {file = "rpds_py-0.10.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:73e0a78a9b843b8c2128028864901f55190401ba38aae685350cf69b98d9f7c9"}, - {file = "rpds_py-0.10.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5ed505ec6305abd2c2c9586a7b04fbd4baf42d4d684a9c12ec6110deefe2a063"}, - {file = "rpds_py-0.10.6-cp310-none-win32.whl", hash = "sha256:d97dd44683802000277bbf142fd9f6b271746b4846d0acaf0cefa6b2eaf2a7ad"}, - {file = "rpds_py-0.10.6-cp310-none-win_amd64.whl", hash = "sha256:b455492cab07107bfe8711e20cd920cc96003e0da3c1f91297235b1603d2aca7"}, - {file = "rpds_py-0.10.6-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:e8cdd52744f680346ff8c1ecdad5f4d11117e1724d4f4e1874f3a67598821069"}, - {file = "rpds_py-0.10.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:66414dafe4326bca200e165c2e789976cab2587ec71beb80f59f4796b786a238"}, - {file = 
"rpds_py-0.10.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc435d059f926fdc5b05822b1be4ff2a3a040f3ae0a7bbbe672babb468944722"}, - {file = "rpds_py-0.10.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8e7f2219cb72474571974d29a191714d822e58be1eb171f229732bc6fdedf0ac"}, - {file = "rpds_py-0.10.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3953c6926a63f8ea5514644b7afb42659b505ece4183fdaaa8f61d978754349e"}, - {file = "rpds_py-0.10.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2bb2e4826be25e72013916eecd3d30f66fd076110de09f0e750163b416500721"}, - {file = "rpds_py-0.10.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bf347b495b197992efc81a7408e9a83b931b2f056728529956a4d0858608b80"}, - {file = "rpds_py-0.10.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:102eac53bb0bf0f9a275b438e6cf6904904908562a1463a6fc3323cf47d7a532"}, - {file = "rpds_py-0.10.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:40f93086eef235623aa14dbddef1b9fb4b22b99454cb39a8d2e04c994fb9868c"}, - {file = "rpds_py-0.10.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e22260a4741a0e7a206e175232867b48a16e0401ef5bce3c67ca5b9705879066"}, - {file = "rpds_py-0.10.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f4e56860a5af16a0fcfa070a0a20c42fbb2012eed1eb5ceeddcc7f8079214281"}, - {file = "rpds_py-0.10.6-cp311-none-win32.whl", hash = "sha256:0774a46b38e70fdde0c6ded8d6d73115a7c39d7839a164cc833f170bbf539116"}, - {file = "rpds_py-0.10.6-cp311-none-win_amd64.whl", hash = "sha256:4a5ee600477b918ab345209eddafde9f91c0acd931f3776369585a1c55b04c57"}, - {file = "rpds_py-0.10.6-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:5ee97c683eaface61d38ec9a489e353d36444cdebb128a27fe486a291647aff6"}, - {file = "rpds_py-0.10.6-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:0713631d6e2d6c316c2f7b9320a34f44abb644fc487b77161d1724d883662e31"}, - {file = "rpds_py-0.10.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5a53f5998b4bbff1cb2e967e66ab2addc67326a274567697379dd1e326bded7"}, - {file = "rpds_py-0.10.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6a555ae3d2e61118a9d3e549737bb4a56ff0cec88a22bd1dfcad5b4e04759175"}, - {file = "rpds_py-0.10.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:945eb4b6bb8144909b203a88a35e0a03d22b57aefb06c9b26c6e16d72e5eb0f0"}, - {file = "rpds_py-0.10.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:52c215eb46307c25f9fd2771cac8135d14b11a92ae48d17968eda5aa9aaf5071"}, - {file = "rpds_py-0.10.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1b3cd23d905589cb205710b3988fc8f46d4a198cf12862887b09d7aaa6bf9b9"}, - {file = "rpds_py-0.10.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:64ccc28683666672d7c166ed465c09cee36e306c156e787acef3c0c62f90da5a"}, - {file = "rpds_py-0.10.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:516a611a2de12fbea70c78271e558f725c660ce38e0006f75139ba337d56b1f6"}, - {file = "rpds_py-0.10.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9ff93d3aedef11f9c4540cf347f8bb135dd9323a2fc705633d83210d464c579d"}, - {file = "rpds_py-0.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d858532212f0650be12b6042ff4378dc2efbb7792a286bee4489eaa7ba010586"}, - {file = "rpds_py-0.10.6-cp312-none-win32.whl", hash = "sha256:3c4eff26eddac49d52697a98ea01b0246e44ca82ab09354e94aae8823e8bda02"}, - {file = "rpds_py-0.10.6-cp312-none-win_amd64.whl", hash = "sha256:150eec465dbc9cbca943c8e557a21afdcf9bab8aaabf386c44b794c2f94143d2"}, - {file = "rpds_py-0.10.6-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:cf693eb4a08eccc1a1b636e4392322582db2a47470d52e824b25eca7a3977b53"}, - {file = 
"rpds_py-0.10.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4134aa2342f9b2ab6c33d5c172e40f9ef802c61bb9ca30d21782f6e035ed0043"}, - {file = "rpds_py-0.10.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e782379c2028a3611285a795b89b99a52722946d19fc06f002f8b53e3ea26ea9"}, - {file = "rpds_py-0.10.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2f6da6d842195fddc1cd34c3da8a40f6e99e4a113918faa5e60bf132f917c247"}, - {file = "rpds_py-0.10.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b4a9fe992887ac68256c930a2011255bae0bf5ec837475bc6f7edd7c8dfa254e"}, - {file = "rpds_py-0.10.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b788276a3c114e9f51e257f2a6f544c32c02dab4aa7a5816b96444e3f9ffc336"}, - {file = "rpds_py-0.10.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:caa1afc70a02645809c744eefb7d6ee8fef7e2fad170ffdeacca267fd2674f13"}, - {file = "rpds_py-0.10.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bddd4f91eede9ca5275e70479ed3656e76c8cdaaa1b354e544cbcf94c6fc8ac4"}, - {file = "rpds_py-0.10.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:775049dfa63fb58293990fc59473e659fcafd953bba1d00fc5f0631a8fd61977"}, - {file = "rpds_py-0.10.6-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:c6c45a2d2b68c51fe3d9352733fe048291e483376c94f7723458cfd7b473136b"}, - {file = "rpds_py-0.10.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0699ab6b8c98df998c3eacf51a3b25864ca93dab157abe358af46dc95ecd9801"}, - {file = "rpds_py-0.10.6-cp38-none-win32.whl", hash = "sha256:ebdab79f42c5961682654b851f3f0fc68e6cc7cd8727c2ac4ffff955154123c1"}, - {file = "rpds_py-0.10.6-cp38-none-win_amd64.whl", hash = "sha256:24656dc36f866c33856baa3ab309da0b6a60f37d25d14be916bd3e79d9f3afcf"}, - {file = "rpds_py-0.10.6-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:0898173249141ee99ffcd45e3829abe7bcee47d941af7434ccbf97717df020e5"}, - {file = 
"rpds_py-0.10.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9e9184fa6c52a74a5521e3e87badbf9692549c0fcced47443585876fcc47e469"}, - {file = "rpds_py-0.10.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5752b761902cd15073a527b51de76bbae63d938dc7c5c4ad1e7d8df10e765138"}, - {file = "rpds_py-0.10.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:99a57006b4ec39dbfb3ed67e5b27192792ffb0553206a107e4aadb39c5004cd5"}, - {file = "rpds_py-0.10.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09586f51a215d17efdb3a5f090d7cbf1633b7f3708f60a044757a5d48a83b393"}, - {file = "rpds_py-0.10.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e225a6a14ecf44499aadea165299092ab0cba918bb9ccd9304eab1138844490b"}, - {file = "rpds_py-0.10.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2039f8d545f20c4e52713eea51a275e62153ee96c8035a32b2abb772b6fc9e5"}, - {file = "rpds_py-0.10.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:34ad87a831940521d462ac11f1774edf867c34172010f5390b2f06b85dcc6014"}, - {file = "rpds_py-0.10.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dcdc88b6b01015da066da3fb76545e8bb9a6880a5ebf89e0f0b2e3ca557b3ab7"}, - {file = "rpds_py-0.10.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:25860ed5c4e7f5e10c496ea78af46ae8d8468e0be745bd233bab9ca99bfd2647"}, - {file = "rpds_py-0.10.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7854a207ef77319ec457c1eb79c361b48807d252d94348305db4f4b62f40f7f3"}, - {file = "rpds_py-0.10.6-cp39-none-win32.whl", hash = "sha256:e6fcc026a3f27c1282c7ed24b7fcac82cdd70a0e84cc848c0841a3ab1e3dea2d"}, - {file = "rpds_py-0.10.6-cp39-none-win_amd64.whl", hash = "sha256:e98c4c07ee4c4b3acf787e91b27688409d918212dfd34c872201273fdd5a0e18"}, - {file = "rpds_py-0.10.6-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:68fe9199184c18d997d2e4293b34327c0009a78599ce703e15cd9a0f47349bba"}, - {file 
= "rpds_py-0.10.6-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:3339eca941568ed52d9ad0f1b8eb9fe0958fa245381747cecf2e9a78a5539c42"}, - {file = "rpds_py-0.10.6-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a360cfd0881d36c6dc271992ce1eda65dba5e9368575663de993eeb4523d895f"}, - {file = "rpds_py-0.10.6-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:031f76fc87644a234883b51145e43985aa2d0c19b063e91d44379cd2786144f8"}, - {file = "rpds_py-0.10.6-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1f36a9d751f86455dc5278517e8b65580eeee37d61606183897f122c9e51cef3"}, - {file = "rpds_py-0.10.6-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:052a832078943d2b2627aea0d19381f607fe331cc0eb5df01991268253af8417"}, - {file = "rpds_py-0.10.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023574366002bf1bd751ebaf3e580aef4a468b3d3c216d2f3f7e16fdabd885ed"}, - {file = "rpds_py-0.10.6-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:defa2c0c68734f4a82028c26bcc85e6b92cced99866af118cd6a89b734ad8e0d"}, - {file = "rpds_py-0.10.6-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:879fb24304ead6b62dbe5034e7b644b71def53c70e19363f3c3be2705c17a3b4"}, - {file = "rpds_py-0.10.6-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:53c43e10d398e365da2d4cc0bcaf0854b79b4c50ee9689652cdc72948e86f487"}, - {file = "rpds_py-0.10.6-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:3777cc9dea0e6c464e4b24760664bd8831738cc582c1d8aacf1c3f546bef3f65"}, - {file = "rpds_py-0.10.6-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:40578a6469e5d1df71b006936ce95804edb5df47b520c69cf5af264d462f2cbb"}, - {file = "rpds_py-0.10.6-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:cf71343646756a072b85f228d35b1d7407da1669a3de3cf47f8bbafe0c8183a4"}, - {file = 
"rpds_py-0.10.6-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10f32b53f424fc75ff7b713b2edb286fdbfc94bf16317890260a81c2c00385dc"}, - {file = "rpds_py-0.10.6-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:81de24a1c51cfb32e1fbf018ab0bdbc79c04c035986526f76c33e3f9e0f3356c"}, - {file = "rpds_py-0.10.6-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac17044876e64a8ea20ab132080ddc73b895b4abe9976e263b0e30ee5be7b9c2"}, - {file = "rpds_py-0.10.6-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5e8a78bd4879bff82daef48c14d5d4057f6856149094848c3ed0ecaf49f5aec2"}, - {file = "rpds_py-0.10.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78ca33811e1d95cac8c2e49cb86c0fb71f4d8409d8cbea0cb495b6dbddb30a55"}, - {file = "rpds_py-0.10.6-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c63c3ef43f0b3fb00571cff6c3967cc261c0ebd14a0a134a12e83bdb8f49f21f"}, - {file = "rpds_py-0.10.6-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:7fde6d0e00b2fd0dbbb40c0eeec463ef147819f23725eda58105ba9ca48744f4"}, - {file = "rpds_py-0.10.6-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:79edd779cfc46b2e15b0830eecd8b4b93f1a96649bcb502453df471a54ce7977"}, - {file = "rpds_py-0.10.6-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:9164ec8010327ab9af931d7ccd12ab8d8b5dc2f4c6a16cbdd9d087861eaaefa1"}, - {file = "rpds_py-0.10.6-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:d29ddefeab1791e3c751e0189d5f4b3dbc0bbe033b06e9c333dca1f99e1d523e"}, - {file = "rpds_py-0.10.6-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:30adb75ecd7c2a52f5e76af50644b3e0b5ba036321c390b8e7ec1bb2a16dd43c"}, - {file = "rpds_py-0.10.6-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd609fafdcdde6e67a139898196698af37438b035b25ad63704fd9097d9a3482"}, - {file = 
"rpds_py-0.10.6-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6eef672de005736a6efd565577101277db6057f65640a813de6c2707dc69f396"}, - {file = "rpds_py-0.10.6-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6cf4393c7b41abbf07c88eb83e8af5013606b1cdb7f6bc96b1b3536b53a574b8"}, - {file = "rpds_py-0.10.6-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ad857f42831e5b8d41a32437f88d86ead6c191455a3499c4b6d15e007936d4cf"}, - {file = "rpds_py-0.10.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d7360573f1e046cb3b0dceeb8864025aa78d98be4bb69f067ec1c40a9e2d9df"}, - {file = "rpds_py-0.10.6-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d08f63561c8a695afec4975fae445245386d645e3e446e6f260e81663bfd2e38"}, - {file = "rpds_py-0.10.6-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:f0f17f2ce0f3529177a5fff5525204fad7b43dd437d017dd0317f2746773443d"}, - {file = "rpds_py-0.10.6-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:442626328600bde1d09dc3bb00434f5374948838ce75c41a52152615689f9403"}, - {file = "rpds_py-0.10.6-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:e9616f5bd2595f7f4a04b67039d890348ab826e943a9bfdbe4938d0eba606971"}, - {file = "rpds_py-0.10.6.tar.gz", hash = "sha256:4ce5a708d65a8dbf3748d2474b580d606b1b9f91b5c6ab2a316e0b0cf7a4ba50"}, + {file = "rpds_py-0.13.2-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:1ceebd0ae4f3e9b2b6b553b51971921853ae4eebf3f54086be0565d59291e53d"}, + {file = "rpds_py-0.13.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:46e1ed994a0920f350a4547a38471217eb86f57377e9314fbaaa329b71b7dfe3"}, + {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee353bb51f648924926ed05e0122b6a0b1ae709396a80eb583449d5d477fcdf7"}, + {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:530190eb0cd778363bbb7596612ded0bb9fef662daa98e9d92a0419ab27ae914"}, + {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29d311e44dd16d2434d5506d57ef4d7036544fc3c25c14b6992ef41f541b10fb"}, + {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e72f750048b32d39e87fc85c225c50b2a6715034848dbb196bf3348aa761fa1"}, + {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db09b98c7540df69d4b47218da3fbd7cb466db0fb932e971c321f1c76f155266"}, + {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2ac26f50736324beb0282c819668328d53fc38543fa61eeea2c32ea8ea6eab8d"}, + {file = "rpds_py-0.13.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:12ecf89bd54734c3c2c79898ae2021dca42750c7bcfb67f8fb3315453738ac8f"}, + {file = "rpds_py-0.13.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a44c8440183b43167fd1a0819e8356692bf5db1ad14ce140dbd40a1485f2dea"}, + {file = "rpds_py-0.13.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bcef4f2d3dc603150421de85c916da19471f24d838c3c62a4f04c1eb511642c1"}, + {file = "rpds_py-0.13.2-cp310-none-win32.whl", hash = "sha256:ee6faebb265e28920a6f23a7d4c362414b3f4bb30607141d718b991669e49ddc"}, + {file = "rpds_py-0.13.2-cp310-none-win_amd64.whl", hash = "sha256:ac96d67b37f28e4b6ecf507c3405f52a40658c0a806dffde624a8fcb0314d5fd"}, + {file = "rpds_py-0.13.2-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:b5f6328e8e2ae8238fc767703ab7b95785521c42bb2b8790984e3477d7fa71ad"}, + {file = "rpds_py-0.13.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:729408136ef8d45a28ee9a7411917c9e3459cf266c7e23c2f7d4bb8ef9e0da42"}, + {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:65cfed9c807c27dee76407e8bb29e6f4e391e436774bcc769a037ff25ad8646e"}, + {file = 
"rpds_py-0.13.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:aefbdc934115d2f9278f153952003ac52cd2650e7313750390b334518c589568"}, + {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d48db29bd47814671afdd76c7652aefacc25cf96aad6daefa82d738ee87461e2"}, + {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3c55d7f2d817183d43220738270efd3ce4e7a7b7cbdaefa6d551ed3d6ed89190"}, + {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6aadae3042f8e6db3376d9e91f194c606c9a45273c170621d46128f35aef7cd0"}, + {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5feae2f9aa7270e2c071f488fab256d768e88e01b958f123a690f1cc3061a09c"}, + {file = "rpds_py-0.13.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:51967a67ea0d7b9b5cd86036878e2d82c0b6183616961c26d825b8c994d4f2c8"}, + {file = "rpds_py-0.13.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d0c10d803549427f427085ed7aebc39832f6e818a011dcd8785e9c6a1ba9b3e"}, + {file = "rpds_py-0.13.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:603d5868f7419081d616dab7ac3cfa285296735e7350f7b1e4f548f6f953ee7d"}, + {file = "rpds_py-0.13.2-cp311-none-win32.whl", hash = "sha256:b8996ffb60c69f677245f5abdbcc623e9442bcc91ed81b6cd6187129ad1fa3e7"}, + {file = "rpds_py-0.13.2-cp311-none-win_amd64.whl", hash = "sha256:5379e49d7e80dca9811b36894493d1c1ecb4c57de05c36f5d0dd09982af20211"}, + {file = "rpds_py-0.13.2-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:8a776a29b77fe0cc28fedfd87277b0d0f7aa930174b7e504d764e0b43a05f381"}, + {file = "rpds_py-0.13.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2a1472956c5bcc49fb0252b965239bffe801acc9394f8b7c1014ae9258e4572b"}, + {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:f252dfb4852a527987a9156cbcae3022a30f86c9d26f4f17b8c967d7580d65d2"}, + {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f0d320e70b6b2300ff6029e234e79fe44e9dbbfc7b98597ba28e054bd6606a57"}, + {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ade2ccb937060c299ab0dfb2dea3d2ddf7e098ed63ee3d651ebfc2c8d1e8632a"}, + {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9d121be0217787a7d59a5c6195b0842d3f701007333426e5154bf72346aa658"}, + {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8fa6bd071ec6d90f6e7baa66ae25820d57a8ab1b0a3c6d3edf1834d4b26fafa2"}, + {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c918621ee0a3d1fe61c313f2489464f2ae3d13633e60f520a8002a5e910982ee"}, + {file = "rpds_py-0.13.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:25b28b3d33ec0a78e944aaaed7e5e2a94ac811bcd68b557ca48a0c30f87497d2"}, + {file = "rpds_py-0.13.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:31e220a040b89a01505128c2f8a59ee74732f666439a03e65ccbf3824cdddae7"}, + {file = "rpds_py-0.13.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:15253fff410873ebf3cfba1cc686a37711efcd9b8cb30ea21bb14a973e393f60"}, + {file = "rpds_py-0.13.2-cp312-none-win32.whl", hash = "sha256:b981a370f8f41c4024c170b42fbe9e691ae2dbc19d1d99151a69e2c84a0d194d"}, + {file = "rpds_py-0.13.2-cp312-none-win_amd64.whl", hash = "sha256:4c4e314d36d4f31236a545696a480aa04ea170a0b021e9a59ab1ed94d4c3ef27"}, + {file = "rpds_py-0.13.2-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:80e5acb81cb49fd9f2d5c08f8b74ffff14ee73b10ca88297ab4619e946bcb1e1"}, + {file = "rpds_py-0.13.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:efe093acc43e869348f6f2224df7f452eab63a2c60a6c6cd6b50fd35c4e075ba"}, + {file = 
"rpds_py-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c2a61c0e4811012b0ba9f6cdcb4437865df5d29eab5d6018ba13cee1c3064a0"}, + {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:751758d9dd04d548ec679224cc00e3591f5ebf1ff159ed0d4aba6a0746352452"}, + {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6ba8858933f0c1a979781272a5f65646fca8c18c93c99c6ddb5513ad96fa54b1"}, + {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bfdfbe6a36bc3059fff845d64c42f2644cf875c65f5005db54f90cdfdf1df815"}, + {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa0379c1935c44053c98826bc99ac95f3a5355675a297ac9ce0dfad0ce2d50ca"}, + {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d5593855b5b2b73dd8413c3fdfa5d95b99d657658f947ba2c4318591e745d083"}, + {file = "rpds_py-0.13.2-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:2a7bef6977043673750a88da064fd513f89505111014b4e00fbdd13329cd4e9a"}, + {file = "rpds_py-0.13.2-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:3ab96754d23372009638a402a1ed12a27711598dd49d8316a22597141962fe66"}, + {file = "rpds_py-0.13.2-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:e06cfea0ece444571d24c18ed465bc93afb8c8d8d74422eb7026662f3d3f779b"}, + {file = "rpds_py-0.13.2-cp38-none-win32.whl", hash = "sha256:5493569f861fb7b05af6d048d00d773c6162415ae521b7010197c98810a14cab"}, + {file = "rpds_py-0.13.2-cp38-none-win_amd64.whl", hash = "sha256:b07501b720cf060c5856f7b5626e75b8e353b5f98b9b354a21eb4bfa47e421b1"}, + {file = "rpds_py-0.13.2-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:881df98f0a8404d32b6de0fd33e91c1b90ed1516a80d4d6dc69d414b8850474c"}, + {file = "rpds_py-0.13.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d79c159adea0f1f4617f54aa156568ac69968f9ef4d1e5fefffc0a180830308e"}, + {file = 
"rpds_py-0.13.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38d4f822ee2f338febcc85aaa2547eb5ba31ba6ff68d10b8ec988929d23bb6b4"}, + {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5d75d6d220d55cdced2f32cc22f599475dbe881229aeddba6c79c2e9df35a2b3"}, + {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d97e9ae94fb96df1ee3cb09ca376c34e8a122f36927230f4c8a97f469994bff"}, + {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:67a429520e97621a763cf9b3ba27574779c4e96e49a27ff8a1aa99ee70beb28a"}, + {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:188435794405c7f0573311747c85a96b63c954a5f2111b1df8018979eca0f2f0"}, + {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:38f9bf2ad754b4a45b8210a6c732fe876b8a14e14d5992a8c4b7c1ef78740f53"}, + {file = "rpds_py-0.13.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:a6ba2cb7d676e9415b9e9ac7e2aae401dc1b1e666943d1f7bc66223d3d73467b"}, + {file = "rpds_py-0.13.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:eaffbd8814bb1b5dc3ea156a4c5928081ba50419f9175f4fc95269e040eff8f0"}, + {file = "rpds_py-0.13.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5a4c1058cdae6237d97af272b326e5f78ee7ee3bbffa6b24b09db4d828810468"}, + {file = "rpds_py-0.13.2-cp39-none-win32.whl", hash = "sha256:b5267feb19070bef34b8dea27e2b504ebd9d31748e3ecacb3a4101da6fcb255c"}, + {file = "rpds_py-0.13.2-cp39-none-win_amd64.whl", hash = "sha256:ddf23960cb42b69bce13045d5bc66f18c7d53774c66c13f24cf1b9c144ba3141"}, + {file = "rpds_py-0.13.2-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:97163a1ab265a1073a6372eca9f4eeb9f8c6327457a0b22ddfc4a17dcd613e74"}, + {file = "rpds_py-0.13.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:25ea41635d22b2eb6326f58e608550e55d01df51b8a580ea7e75396bafbb28e9"}, 
+ {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76d59d4d451ba77f08cb4cd9268dec07be5bc65f73666302dbb5061989b17198"}, + {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e7c564c58cf8f248fe859a4f0fe501b050663f3d7fbc342172f259124fb59933"}, + {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61dbc1e01dc0c5875da2f7ae36d6e918dc1b8d2ce04e871793976594aad8a57a"}, + {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fdb82eb60d31b0c033a8e8ee9f3fc7dfbaa042211131c29da29aea8531b4f18f"}, + {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d204957169f0b3511fb95395a9da7d4490fb361763a9f8b32b345a7fe119cb45"}, + {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c45008ca79bad237cbc03c72bc5205e8c6f66403773929b1b50f7d84ef9e4d07"}, + {file = "rpds_py-0.13.2-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:79bf58c08f0756adba691d480b5a20e4ad23f33e1ae121584cf3a21717c36dfa"}, + {file = "rpds_py-0.13.2-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:e86593bf8637659e6a6ed58854b6c87ec4e9e45ee8a4adfd936831cef55c2d21"}, + {file = "rpds_py-0.13.2-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:d329896c40d9e1e5c7715c98529e4a188a1f2df51212fd65102b32465612b5dc"}, + {file = "rpds_py-0.13.2-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:4a5375c5fff13f209527cd886dc75394f040c7d1ecad0a2cb0627f13ebe78a12"}, + {file = "rpds_py-0.13.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:06d218e4464d31301e943b65b2c6919318ea6f69703a351961e1baaf60347276"}, + {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1f41d32a2ddc5a94df4b829b395916a4b7f103350fa76ba6de625fcb9e773ac"}, + 
{file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6bc568b05e02cd612be53900c88aaa55012e744930ba2eeb56279db4c6676eb3"}, + {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d94d78418203904730585efa71002286ac4c8ac0689d0eb61e3c465f9e608ff"}, + {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bed0252c85e21cf73d2d033643c945b460d6a02fc4a7d644e3b2d6f5f2956c64"}, + {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:244e173bb6d8f3b2f0c4d7370a1aa341f35da3e57ffd1798e5b2917b91731fd3"}, + {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7f55cd9cf1564b7b03f238e4c017ca4794c05b01a783e9291065cb2858d86ce4"}, + {file = "rpds_py-0.13.2-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:f03a1b3a4c03e3e0161642ac5367f08479ab29972ea0ffcd4fa18f729cd2be0a"}, + {file = "rpds_py-0.13.2-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:f5f4424cb87a20b016bfdc157ff48757b89d2cc426256961643d443c6c277007"}, + {file = "rpds_py-0.13.2-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:c82bbf7e03748417c3a88c1b0b291288ce3e4887a795a3addaa7a1cfd9e7153e"}, + {file = "rpds_py-0.13.2-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:c0095b8aa3e432e32d372e9a7737e65b58d5ed23b9620fea7cb81f17672f1fa1"}, + {file = "rpds_py-0.13.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:4c2d26aa03d877c9730bf005621c92da263523a1e99247590abbbe252ccb7824"}, + {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96f2975fb14f39c5fe75203f33dd3010fe37d1c4e33177feef1107b5ced750e3"}, + {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4dcc5ee1d0275cb78d443fdebd0241e58772a354a6d518b1d7af1580bbd2c4e8"}, + {file = 
"rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61d42d2b08430854485135504f672c14d4fc644dd243a9c17e7c4e0faf5ed07e"}, + {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d3a61e928feddc458a55110f42f626a2a20bea942ccedb6fb4cee70b4830ed41"}, + {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7de12b69d95072394998c622cfd7e8cea8f560db5fca6a62a148f902a1029f8b"}, + {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:87a90f5545fd61f6964e65eebde4dc3fa8660bb7d87adb01d4cf17e0a2b484ad"}, + {file = "rpds_py-0.13.2-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:9c95a1a290f9acf7a8f2ebbdd183e99215d491beea52d61aa2a7a7d2c618ddc6"}, + {file = "rpds_py-0.13.2-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:35f53c76a712e323c779ca39b9a81b13f219a8e3bc15f106ed1e1462d56fcfe9"}, + {file = "rpds_py-0.13.2-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:96fb0899bb2ab353f42e5374c8f0789f54e0a94ef2f02b9ac7149c56622eaf31"}, + {file = "rpds_py-0.13.2.tar.gz", hash = "sha256:f8eae66a1304de7368932b42d801c67969fd090ddb1a7a24f27b435ed4bed68f"}, ] [[package]] name = "safetensors" -version = "0.4.0" +version = "0.4.1" description = "" optional = false python-versions = ">=3.7" files = [ - {file = "safetensors-0.4.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:2289ae6dbe6d027ecee016b28ced13a2e21a0b3a3a757a23033a2d1c0b1bad55"}, - {file = "safetensors-0.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bf6458959f310f551cbbeef2255527ade5f783f952738e73e4d0136198cc3bfe"}, - {file = "safetensors-0.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6b60a58a8f7cc7aed3b5b73dce1f5259a53c83d9ba43a76a874e6ad868c1b4d"}, - {file = "safetensors-0.4.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:491b3477e4d0d4599bb75d79da4b75af2e6ed9b1f6ec2b715991f0bc927bf09a"}, - {file = "safetensors-0.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59d2e10b7e0cd18bb73ed7c17c624a5957b003b81345e18159591771c26ee428"}, - {file = "safetensors-0.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f667a4c12fb593f5f66ce966cb1b14a7148898b2b1a7f79e0761040ae1e3c51"}, - {file = "safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f9909512bcb6f712bdd04c296cdfb0d8ff73d258ffc5af884bb62ea02d221e0"}, - {file = "safetensors-0.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d33d29e846821f0e4f92614022949b09ccf063cb36fe2f9fe099cde1efbfbb87"}, - {file = "safetensors-0.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4d512525a8e05a045ce6698066ba0c5378c174a83e0b3720a8c7799dc1bb06f3"}, - {file = "safetensors-0.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0219cea445177f6ad1f9acd3a8d025440c8ff436d70a4a7c7ba9c36066aa9474"}, - {file = "safetensors-0.4.0-cp310-none-win32.whl", hash = "sha256:67ab171eeaad6972d3971c53d29d53353c67f6743284c6d637b59fa3e54c8a94"}, - {file = "safetensors-0.4.0-cp310-none-win_amd64.whl", hash = "sha256:7ffc736039f08a9ca1f09816a7481b8e4469c06e8f8a5ffa8cb67ddd79e6d77f"}, - {file = "safetensors-0.4.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:4fe9e3737b30de458225a23926219ca30b902ee779b6a3df96eaab2b6d625ec2"}, - {file = "safetensors-0.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e7916e814a90008de767b1c164a1d83803693c661ffe9af5a697b22e2752edb0"}, - {file = "safetensors-0.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cbc4a4da01143472323c145f3c289e5f6fabde0ac0a3414dabf912a21692fff4"}, - {file = "safetensors-0.4.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a54c21654a47669b38e359e8f852af754b786c9da884bb61ad5e9af12bd71ccb"}, - {file = 
"safetensors-0.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:25cd407955bad5340ba17f9f8ac789a0d751601a311e2f7b2733f9384478c95e"}, - {file = "safetensors-0.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:82e8fc4e3503cd738fd40718a430fe0e5ce6e7ff91a73d6ce628bbb89c41e8ce"}, - {file = "safetensors-0.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48b92059b1a4ad163024d4f526e0e73ebe2bb3ae70537e15e347820b4de5dc27"}, - {file = "safetensors-0.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5daa05058f7dce85b5f9f60c4eab483ed7859d63978f08a76e52e78859ff20ca"}, - {file = "safetensors-0.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a86565a5c112dd855909e20144947b4f53abb78c4de207f36ca71ee63ba5b90d"}, - {file = "safetensors-0.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:38032078ed9fea52d06584e441bccc73fb475c4581600c6d6166de2fe2deb3d1"}, - {file = "safetensors-0.4.0-cp311-none-win32.whl", hash = "sha256:2f99d90c91b7c76b40a862acd9085bc77f7974a27dee7cfcebe46149af5a99a1"}, - {file = "safetensors-0.4.0-cp311-none-win_amd64.whl", hash = "sha256:74e2a448ffe19be188b457b130168190ee73b5a75e45ba96796320c1f5ae35d2"}, - {file = "safetensors-0.4.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:1e2f9c69b41d03b4826ffb96b29e07444bb6b34a78a7bafd0b88d59e8ec75b8a"}, - {file = "safetensors-0.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3910fb5bf747413b59f1a34e6d2a993b589fa7d919709518823c70efaaa350bd"}, - {file = "safetensors-0.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf8fdca709b2470a35a59b1e6dffea75cbe1214b22612b5dd4c93947697aea8b"}, - {file = "safetensors-0.4.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2f27b8ef814c5fb43456caeb7f3cbb889b76115180aad1f42402839c14a47c5b"}, - {file = "safetensors-0.4.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:7b2d6101eccc43c7be0cb052f13ceda64288b3d8b344b988ed08d7133cbce2f3"}, - {file = "safetensors-0.4.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fdc34027b545a69be3d4220c140b276129523e4e46db06ad1a0b60d6a4cf9214"}, - {file = "safetensors-0.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db7bb48ca9e90bb9526c71b388d38d8de160c0354f4c5126df23e8701a870dcb"}, - {file = "safetensors-0.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a78ffc0795d3595cd9e4d453502e35f764276c49e434b25556a15a337db4dafc"}, - {file = "safetensors-0.4.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:8e735b0f79090f6855b55e205e820b7b595502ffca0009a5c13eef3661ce465b"}, - {file = "safetensors-0.4.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f8d2416734e850d5392afffbcb2b8985ea29fb171f1cb197e2ae51b8e35d6438"}, - {file = "safetensors-0.4.0-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:e853e189ba7d47eaf561094586692ba2bbdd258c096f1755805cac098de0e6ab"}, - {file = "safetensors-0.4.0-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:4b2aa57b5a4d576f3d1dd6e56980026340f156f8a13c13016bfac4e25295b53f"}, - {file = "safetensors-0.4.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b6c1316ffde6cb4bf22c7445bc9fd224b4d1b9dd7320695f5611c89e802e4b6"}, - {file = "safetensors-0.4.0-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:003077ec85261d00061058fa12e3c1d2055366b02ce8f2938929359ffbaff2b8"}, - {file = "safetensors-0.4.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bd63d83a92f1437a8b0431779320376030ae43ace980bea5686d515de0784100"}, - {file = "safetensors-0.4.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2077801800b4b13301d8d6290c7fb5bd60737320001717153ebc4371776643b5"}, - {file = "safetensors-0.4.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:7abe0e157a49a75aeeccfbc4f3dac38d8f98512d3cdb35c200f8e628dc5773cf"}, - {file = "safetensors-0.4.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3bfed574f6b1e7e7fe1f17213278875ef6c6e8b1582ab6eda93947db1178cae6"}, - {file = "safetensors-0.4.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:964ef166a286ce3b023d0d0bd0e21d440a1c8028981c8abdb136bc7872ba9b3d"}, - {file = "safetensors-0.4.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:44f84373e42183bd56a13a1f2d8acb1db7fedaeffbd83e79cec861477eee1af4"}, - {file = "safetensors-0.4.0-cp37-none-win32.whl", hash = "sha256:c68132727dd86fb641102e494d445f705efe402f4d5e24b278183a15499ab400"}, - {file = "safetensors-0.4.0-cp37-none-win_amd64.whl", hash = "sha256:1db87155454c168aef118d5657a403aee48a4cb08d8851a981157f07351ea317"}, - {file = "safetensors-0.4.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:9e583fa68e5a07cc859c4e13c1ebff12029904aa2e27185cf04a1f57fe9a81c4"}, - {file = "safetensors-0.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:73e7696dcf3f72f99545eb1abe6106ad65ff1f62381d6ce4b34be3272552897a"}, - {file = "safetensors-0.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4936096a57c62e84e200f92620a536be067fc5effe46ecc7f230ebb496ecd579"}, - {file = "safetensors-0.4.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87b328ee1591adac332543e1f5fc2c2d7f149b745ebb0d58d7850818ff9cee27"}, - {file = "safetensors-0.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b69554c143336256260eceff1d3c0969172a641b54d4668489a711b05f92a2c0"}, - {file = "safetensors-0.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ebf6bcece5d5d1bd6416472f94604d2c834ca752ac60ed42dba7157e595a990"}, - {file = "safetensors-0.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6686ce01b8602d55a7d9903c90d4a6e6f90aeb6ddced7cf4605892d0ba94bcb8"}, - {file = 
"safetensors-0.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9b8fd6cc2f3bda444a048b541c843c7b7fefc89c4120d7898ea7d5b026e93891"}, - {file = "safetensors-0.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8a6abfe67692f81b8bdb99c837f28351c17e624ebf136970c850ee989c720446"}, - {file = "safetensors-0.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:27a24ca8822c469ee452db4c13418ba983315a0d863c018a9af15f2305eac38c"}, - {file = "safetensors-0.4.0-cp38-none-win32.whl", hash = "sha256:c4a0a47c8640167792d8261ee21b26430bbc39130a7edaad7f4c0bc05669d00e"}, - {file = "safetensors-0.4.0-cp38-none-win_amd64.whl", hash = "sha256:a738970a367f39249e2abb900d9441a8a86d7ff50083e5eaa6e7760a9f216014"}, - {file = "safetensors-0.4.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:806379f37e1abd5d302288c4b2f4186dd7ea7143d4c7811f90a8077f0ae8967b"}, - {file = "safetensors-0.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2b9b94133ed2ae9dda0e95dcace7b7556eba023ffa4c4ae6df8f99377f571d6a"}, - {file = "safetensors-0.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b563a14c43614815a6b524d2e4edeaace50b717f7e7487bb227dd5b68350f5a"}, - {file = "safetensors-0.4.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:00a9b157be660fb7ba88fa2eedd05ec93793a5b61e43e783e10cb0b995372802"}, - {file = "safetensors-0.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c8f194f45ab6aa767993c24f0aeb950af169dbc5d611b94c9021a1d13b8a1a34"}, - {file = "safetensors-0.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:469360b9451db10bfed3881378d5a71b347ecb1ab4f42367d77b8164a13af70b"}, - {file = "safetensors-0.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5f75fa97ccf32a3c7af476c6a0e851023197d3c078f6de3612008fff94735f9"}, - {file = "safetensors-0.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = 
"sha256:acf0180283c2efae72f1d8c0a4a7974662091df01be3aa43b5237b1e52ed0a01"}, - {file = "safetensors-0.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:cd02b495ba0814619f40bda46771bb06dbbf1d42524b66fa03b2a736c77e4515"}, - {file = "safetensors-0.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c42bdea183dbaa99e2f0e6120dc524df79cf4289a6f90f30a534444ef20f49fa"}, - {file = "safetensors-0.4.0-cp39-none-win32.whl", hash = "sha256:cef7bb5d9feae7146c3c3c7b3aef7d2c8b39ba7f5ff4252d368eb69462a47076"}, - {file = "safetensors-0.4.0-cp39-none-win_amd64.whl", hash = "sha256:79dd46fb1f19282fd12f544471efb97823ede927cedbf9cf35550d92b349fdd2"}, - {file = "safetensors-0.4.0-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:002301c1afa32909f83745b0c124d002e7ae07e15671f3b43cbebd0ffc5e6037"}, - {file = "safetensors-0.4.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:67762d36ae088c73d4a3c96bfc4ea8d31233554f35b6cace3a18533238d462ea"}, - {file = "safetensors-0.4.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f45230f20a206e5e4c7f7bbf9342178410c6f8b0af889843aa99045a76f7691"}, - {file = "safetensors-0.4.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f2ca939bbd8fb2f4dfa28e39a146dad03bc9325e9fc831b68f7b98f69a5a2f1"}, - {file = "safetensors-0.4.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:61a00f281391fae5ce91df70918bb61c12d2d514a493fd8056e12114be729911"}, - {file = "safetensors-0.4.0-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:435fd136a42492b280cb55126f9ce9535b35dd49df2c5d572a5945455a439448"}, - {file = "safetensors-0.4.0-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f0daa788273d683258fb1e4a5e16bef4486b2fca536451a2591bc0f4a6488895"}, - {file = "safetensors-0.4.0-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:0620ab0d41e390ccb1c4ea8f63dc00cb5f0b96a5cdd3cd0d64c21765720c074a"}, - {file = 
"safetensors-0.4.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc1fa8d067733cb67f22926689ee808f08afacf7700d2ffb44efae90a0693eb1"}, - {file = "safetensors-0.4.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcaa40bc363edda145db75cd030f3b1822e5478d550c3500a42502ecef32c959"}, - {file = "safetensors-0.4.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b561fbc044db7beff2ece0ec219a291809d45a38d30c6b38e7cc46482582f4ba"}, - {file = "safetensors-0.4.0-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:79a983b09782dacf9a1adb19bb98f4a8f6c3144108939f572c047b5797e43cf5"}, - {file = "safetensors-0.4.0-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:10b65cd3ad79f5d0daf281523b4146bc271a34bb7430d4e03212e0de8622dab8"}, - {file = "safetensors-0.4.0-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:114decacc475a6a9e2f9102a00c171d113ddb5d35cb0bda0db2c0c82b2eaa9ce"}, - {file = "safetensors-0.4.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:72ddb741dd5fe42521db76a70e012f76995516a12e7e0ef26be03ea9be77802a"}, - {file = "safetensors-0.4.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c5556c2ec75f5a6134866eddd7341cb36062e6edaea343478a279591b63ddba"}, - {file = "safetensors-0.4.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed50f239b0ce7ae85b078395593b4a351ede7e6f73af25f4873e3392336f64c9"}, - {file = "safetensors-0.4.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:495dcaea8fbab70b927d2274e2547824462737acbf98ccd851a71124f779a5c6"}, - {file = "safetensors-0.4.0-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3f4d90c79a65ba2fe2ff0876f6140748f0a3ce6a21e27a35190f4f96321803f8"}, - {file = "safetensors-0.4.0-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7a524382b5c55b5fbb168e0e9d3f502450c8cf3fb81b93e880018437c206a482"}, - {file = 
"safetensors-0.4.0-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:9849ea60c7e840bfdd6030ad454d4a6ba837b3398c902f15a30460dd6961c28c"}, - {file = "safetensors-0.4.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:6c42623ae7045615d9eaa6877b9df1db4e9cc71ecc14bcc721ea1e475dddd595"}, - {file = "safetensors-0.4.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80cb8342f00f3c41b3b93b1a599b84723280d3ac90829bc62262efc03ab28793"}, - {file = "safetensors-0.4.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8c4f5ed4ede384dea8c99bae76b0718a828dbf7b2c8ced1f44e3b9b1a124475"}, - {file = "safetensors-0.4.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:40d7cf03493bfe75ef62e2c716314474b28d9ba5bf4909763e4b8dd14330c01a"}, - {file = "safetensors-0.4.0-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:232029f0a9fa6fa1f737324eda98a700409811186888536a2333cbbf64e41741"}, - {file = "safetensors-0.4.0-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:9ed55f4a20c78ff3e8477efb63c8303c2152cdfb3bfea4d025a80f54d38fd628"}, - {file = "safetensors-0.4.0.tar.gz", hash = "sha256:b985953c3cf11e942eac4317ef3db3da713e274109cf7cfb6076d877054f013e"}, + {file = "safetensors-0.4.1-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:cba01c6b76e01ec453933b3b3c0157c59b52881c83eaa0f7666244e71aa75fd1"}, + {file = "safetensors-0.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7a8f6f679d97ea0135c7935c202feefbd042c149aa70ee759855e890c01c7814"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbc2ce1f5ae5143a7fb72b71fa71db6a42b4f6cf912aa3acdc6b914084778e68"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2d87d993eaefe6611a9c241a8bd364a5f1ffed5771c74840363a6c4ed8d868f6"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", 
hash = "sha256:097e9af2efa8778cd2f0cba451784253e62fa7cc9fc73c0744d27212f7294e25"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d10a9f7bae608ccfdc009351f01dc3d8535ff57f9488a58a4c38e45bf954fe93"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:270b99885ec14abfd56c1d7f28ada81740a9220b4bae960c3de1c6fe84af9e4d"}, + {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:285b52a481e7ba93e29ad4ec5841ef2c4479ef0a6c633c4e2629e0508453577b"}, + {file = "safetensors-0.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c3c9f0ca510e0de95abd6424789dcbc879942a3a4e29b0dfa99d9427bf1da75c"}, + {file = "safetensors-0.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:88b4653059c903015284a9722f9a46838c654257173b279c8f6f46dbe80b612d"}, + {file = "safetensors-0.4.1-cp310-none-win32.whl", hash = "sha256:2fe6926110e3d425c4b684a4379b7796fdc26ad7d16922ea1696c8e6ea7e920f"}, + {file = "safetensors-0.4.1-cp310-none-win_amd64.whl", hash = "sha256:a79e16222106b2f5edbca1b8185661477d8971b659a3c814cc6f15181a9b34c8"}, + {file = "safetensors-0.4.1-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:d93321eea0dd7e81b283e47a1d20dee6069165cc158286316d0d06d340de8fe8"}, + {file = "safetensors-0.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8ff8e41c8037db17de0ea2a23bc684f43eaf623be7d34906fe1ac10985b8365e"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39d36f1d88468a87c437a1bc27c502e71b6ca44c385a9117a9f9ba03a75cc9c6"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7ef010e9afcb4057fb6be3d0a0cfa07aac04fe97ef73fe4a23138d8522ba7c17"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b287304f2b2220d51ccb51fd857761e78bcffbeabe7b0238f8dc36f2edfd9542"}, + {file = 
"safetensors-0.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e09000b2599e1836314430f81a3884c66a5cbabdff5d9f175b5d560d4de38d78"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9c80ce0001efa16066358d2dd77993adc25f5a6c61850e4ad096a2232930bce"}, + {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:413e1f6ac248f7d1b755199a06635e70c3515493d3b41ba46063dec33aa2ebb7"}, + {file = "safetensors-0.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3ac139377cfe71ba04573f1cda66e663b7c3e95be850e9e6c2dd4b5984bd513"}, + {file = "safetensors-0.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:04157d008385bea66d12fe90844a80d4a76dc25ec5230b5bd9a630496d1b7c03"}, + {file = "safetensors-0.4.1-cp311-none-win32.whl", hash = "sha256:5f25297148ec665f0deb8bd67e9564634d8d6841041ab5393ccfe203379ea88b"}, + {file = "safetensors-0.4.1-cp311-none-win_amd64.whl", hash = "sha256:b2f8877990a72ff595507b80f4b69036a9a1986a641f8681adf3425d97d3d2a5"}, + {file = "safetensors-0.4.1-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:eb2c1da1cc39509d1a55620a5f4d14f8911c47a89c926a96e6f4876e864375a3"}, + {file = "safetensors-0.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:303d2c0415cf15a28f8d7f17379ea3c34c2b466119118a34edd9965983a1a8a6"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb4cb3e37a9b961ddd68e873b29fe9ab4a081e3703412e34aedd2b7a8e9cafd9"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ae5497adc68669db2fed7cb2dad81e6a6106e79c9a132da3efdb6af1db1014fa"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b30abd0cddfe959d1daedf92edcd1b445521ebf7ddefc20860ed01486b33c90"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:d784a98c492c751f228a4a894c3b8a092ff08b24e73b5568938c28b8c0e8f8df"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e57a5ab08b0ec7a7caf30d2ac79bb30c89168431aca4f8854464bb9461686925"}, + {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:edcf3121890b5f0616aa5a54683b1a5d2332037b970e507d6bb7841a3a596556"}, + {file = "safetensors-0.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fdb58dee173ef33634c3016c459d671ca12d11e6acf9db008261cbe58107e579"}, + {file = "safetensors-0.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:780dc21eb3fd32ddd0e8c904bdb0290f2454f4ac21ae71e94f9ce72db1900a5a"}, + {file = "safetensors-0.4.1-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:48901bd540f8a3c1791314bc5c8a170927bf7f6acddb75bf0a263d081a3637d4"}, + {file = "safetensors-0.4.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:3b0b7b2d5976fbed8a05e2bbdce5816a59e6902e9e7c7e07dc723637ed539787"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f69903ff49cb30b9227fb5d029bea276ea20d04b06803877a420c5b1b74c689"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0ddd050e01f3e843aa8c1c27bf68675b8a08e385d0045487af4d70418c3cb356"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a82bc2bd7a9a0e08239bdd6d7774d64121f136add93dfa344a2f1a6d7ef35fa"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6ace9e66a40f98a216ad661245782483cf79cf56eb2b112650bb904b0baa9db5"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82cbb8f4d022f2e94498cbefca900698b8ded3d4f85212f47da614001ff06652"}, + {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = 
"sha256:791edc10a3c359a2f5f52d5cddab0df8a45107d91027d86c3d44e57162e5d934"}, + {file = "safetensors-0.4.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:83c2cfbe8c6304f0891e7bb378d56f66d2148972eeb5f747cd8a2246886f0d8c"}, + {file = "safetensors-0.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:04dd14f53f5500eb4c4149674216ba1000670efbcf4b1b5c2643eb244e7882ea"}, + {file = "safetensors-0.4.1-cp37-none-win32.whl", hash = "sha256:d5b3defa74f3723a388bfde2f5d488742bc4879682bd93267c09a3bcdf8f869b"}, + {file = "safetensors-0.4.1-cp37-none-win_amd64.whl", hash = "sha256:25a043cbb59d4f75e9dd87fdf5c009dd8830105a2c57ace49b72167dd9808111"}, + {file = "safetensors-0.4.1-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:3f6a520af7f2717c5ecba112041f2c8af1ca6480b97bf957aba81ed9642e654c"}, + {file = "safetensors-0.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c3807ac3b16288dffebb3474b555b56fe466baa677dfc16290dcd02dca1ab228"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b58ba13a9e82b4bc3fc221914f6ef237fe6c2adb13cede3ace64d1aacf49610"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dac4bb42f8679aadc59bd91a4c5a1784a758ad49d0912995945cd674089f628e"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:911b48dc09e321a194def3a7431662ff4f03646832f3a8915bbf0f449b8a5fcb"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:82571d20288c975c1b30b08deb9b1c3550f36b31191e1e81fae87669a92217d0"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da52ee0dc8ba03348ffceab767bd8230842fdf78f8a996e2a16445747143a778"}, + {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2536b11ce665834201072e9397404170f93f3be10cca9995b909f023a04501ee"}, + {file = 
"safetensors-0.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:998fbac99ca956c3a09fe07cc0b35fac26a521fa8865a690686d889f0ff4e4a6"}, + {file = "safetensors-0.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:845be0aafabf2a60c2d482d4e93023fecffe5e5443d801d7a7741bae9de41233"}, + {file = "safetensors-0.4.1-cp38-none-win32.whl", hash = "sha256:ce7a28bc8af685a69d7e869d09d3e180a275e3281e29cf5f1c7319e231932cc7"}, + {file = "safetensors-0.4.1-cp38-none-win_amd64.whl", hash = "sha256:e056fb9e22d118cc546107f97dc28b449d88274207dd28872bd668c86216e4f6"}, + {file = "safetensors-0.4.1-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:bdc0d039e44a727824639824090bd8869535f729878fa248addd3dc01db30eae"}, + {file = "safetensors-0.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3c1b1d510c7aba71504ece87bf393ea82638df56303e371e5e2cf09d18977dd7"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bd0afd95c1e497f520e680ea01e0397c0868a3a3030e128438cf6e9e3fcd671"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f603bdd8deac6726d39f41688ed353c532dd53935234405d79e9eb53f152fbfb"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8a85e3e47e0d4eebfaf9a58b40aa94f977a56050cb5598ad5396a9ee7c087c6"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0ccb5aa0f3be2727117e5631200fbb3a5b3a2b3757545a92647d6dd8be6658f"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d784938534e255473155e4d9f276ee69eb85455b6af1292172c731409bf9adee"}, + {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a257de175c254d39ccd6a21341cd62eb7373b05c1e618a78096a56a857e0c316"}, + {file = "safetensors-0.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = 
"sha256:6fd80f7794554091836d4d613d33a7d006e2b8d6ba014d06f97cebdfda744f64"}, + {file = "safetensors-0.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:35803201d980efcf964b75a0a2aee97fe5e9ecc5f3ad676b38fafdfe98e0620d"}, + {file = "safetensors-0.4.1-cp39-none-win32.whl", hash = "sha256:7ff8a36e0396776d3ed9a106fc9a9d7c55d4439ca9a056a24bf66d343041d3e6"}, + {file = "safetensors-0.4.1-cp39-none-win_amd64.whl", hash = "sha256:bfa2e20342b81921b98edba52f8deb68843fa9c95250739a56b52ceda5ea5c61"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:ae2d5a31cfb8a973a318f7c4d2cffe0bd1fe753cdf7bb41a1939d45a0a06f964"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1a45dbf03e8334d3a5dc93687d98b6dc422f5d04c7d519dac09b84a3c87dd7c6"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2297b359d91126c0f9d4fd17bae3cfa2fe3a048a6971b8db07db746ad92f850c"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bda3d98e2bcece388232cfc551ebf063b55bdb98f65ab54df397da30efc7dcc5"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f8934bdfd202ebd0697040a3dff40dd77bc4c5bbf3527ede0532f5e7fb4d970f"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:42c3710cec7e5c764c7999697516370bee39067de0aa089b7e2cfb97ac8c6b20"}, + {file = "safetensors-0.4.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:53134226053e56bd56e73f7db42596e7908ed79f3c9a1016e4c1dade593ac8e5"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:257d59e40a1b367cb544122e7451243d65b33c3f34d822a347f4eea6fdf97fdf"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d54c2f1826e790d1eb2d2512bfd0ee443f0206b423d6f27095057c7f18a0687"}, + 
{file = "safetensors-0.4.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:645b3f1138fce6e818e79d4128afa28f0657430764cc045419c1d069ff93f732"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e9a7ffb1e551c6df51d267f5a751f042b183df22690f6feceac8d27364fd51d7"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:44e230fbbe120de564b64f63ef3a8e6ff02840fa02849d9c443d56252a1646d4"}, + {file = "safetensors-0.4.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:9d16b3b2fcc6fca012c74bd01b5619c655194d3e3c13e4d4d0e446eefa39a463"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:5d95ea4d8b32233910734a904123bdd3979c137c461b905a5ed32511defc075f"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:dab431699b5d45e0ca043bc580651ce9583dda594e62e245b7497adb32e99809"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16d8bbb7344e39cb9d4762e85c21df94ebeb03edac923dd94bb9ed8c10eac070"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1faf5111c66a6ba91f85dff2e36edaaf36e6966172703159daeef330de4ddc7b"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:660ca1d8bff6c7bc7c6b30b9b32df74ef3ab668f5df42cefd7588f0d40feadcb"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ae2f67f04ed0bb2e56fd380a8bd3eef03f609df53f88b6f5c7e89c08e52aae00"}, + {file = "safetensors-0.4.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:c8ed5d2c04cdc1afc6b3c28d59580448ac07732c50d94c15e14670f9c473a2ce"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:2b6a2814278b6660261aa9a9aae524616de9f1ec364e3716d219b6ed8f91801f"}, + {file = 
"safetensors-0.4.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:3cfd1ca35eacc635f0eaa894e5c5ed83ffebd0f95cac298fd430014fa7323631"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4177b456c6b0c722d82429127b5beebdaf07149d265748e97e0a34ff0b3694c8"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:313e8472197bde54e3ec54a62df184c414582979da8f3916981b6a7954910a1b"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fdb4adb76e21bad318210310590de61c9f4adcef77ee49b4a234f9dc48867869"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:1d568628e9c43ca15eb96c217da73737c9ccb07520fafd8a1eba3f2750614105"}, + {file = "safetensors-0.4.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:573b6023a55a2f28085fc0a84e196c779b6cbef4d9e73acea14c8094fee7686f"}, + {file = "safetensors-0.4.1.tar.gz", hash = "sha256:2304658e6ada81a5223225b4efe84748e760c46079bffedf7e321763cafb36c9"}, ] [package.extras] @@ -3816,13 +3825,13 @@ win32 = ["pywin32"] [[package]] name = "sentry-sdk" -version = "1.32.0" +version = "1.38.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.32.0.tar.gz", hash = "sha256:935e8fbd7787a3702457393b74b13d89a5afb67185bc0af85c00cb27cbd42e7c"}, - {file = "sentry_sdk-1.32.0-py2.py3-none-any.whl", hash = "sha256:eeb0b3550536f3bbc05bb1c7e0feb3a78d74acb43b607159a606ed2ec0a33a4d"}, + {file = "sentry-sdk-1.38.0.tar.gz", hash = "sha256:8feab81de6bbf64f53279b085bd3820e3e737403b0a0d9317f73a2c3374ae359"}, + {file = "sentry_sdk-1.38.0-py2.py3-none-any.whl", hash = "sha256:0017fa73b8ae2d4e57fd2522ee3df30453715b29d2692142793ec5d5f90b94a6"}, ] [package.dependencies] @@ -3961,17 +3970,17 @@ test = ["pytest"] [[package]] name = "setuptools" -version = "68.2.2" +version = 
"69.0.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-68.2.2-py3-none-any.whl", hash = "sha256:b454a35605876da60632df1a60f736524eb73cc47bbc9f3f1ef1b644de74fd2a"}, - {file = "setuptools-68.2.2.tar.gz", hash = "sha256:4ac1475276d2f1c48684874089fefcd83bd7162ddaafb81fac866ba0db282a87"}, + {file = "setuptools-69.0.2-py3-none-any.whl", hash = "sha256:1e8fdff6797d3865f37397be788a4e3cba233608e9b509382a2777d25ebde7f2"}, + {file = "setuptools-69.0.2.tar.gz", hash = "sha256:735896e78a4742605974de002ac60562d286fa8051a7e2299445e8e8fbb01aa6"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] @@ -4268,13 +4277,13 @@ doc = ["reno", 
"sphinx", "tornado (>=4.5)"] [[package]] name = "terminado" -version = "0.17.1" +version = "0.18.0" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "terminado-0.17.1-py3-none-any.whl", hash = "sha256:8650d44334eba354dd591129ca3124a6ba42c3d5b70df5051b6921d506fdaeae"}, - {file = "terminado-0.17.1.tar.gz", hash = "sha256:6ccbbcd3a4f8a25a5ec04991f39a0b8db52dfcd487ea0e578d977e6752380333"}, + {file = "terminado-0.18.0-py3-none-any.whl", hash = "sha256:87b0d96642d0fe5f5abd7783857b9cab167f221a39ff98e3b9619a788a3c0f2e"}, + {file = "terminado-0.18.0.tar.gz", hash = "sha256:1ea08a89b835dd1b8c0c900d92848147cef2537243361b2e3f4dc15df9b6fded"}, ] [package.dependencies] @@ -4285,6 +4294,7 @@ tornado = ">=6.1.0" [package.extras] docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"] +typing = ["mypy (>=1.6,<2.0)", "traitlets (>=5.11.1)"] [[package]] name = "tinycss2" @@ -4306,113 +4316,113 @@ test = ["flake8", "isort", "pytest"] [[package]] name = "tokenizers" -version = "0.14.1" +version = "0.15.0" description = "" optional = false python-versions = ">=3.7" files = [ - {file = "tokenizers-0.14.1-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:04ec1134a18ede355a05641cdc7700f17280e01f69f2f315769f02f7e295cf1e"}, - {file = "tokenizers-0.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:638abedb39375f0ddce2de536fc9c976639b2d1b7202d715c2e7a25f0ebfd091"}, - {file = "tokenizers-0.14.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:901635098565773a44f74068639d265f19deaaca47ea77b428fd9bee13a61d87"}, - {file = "tokenizers-0.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:72e95184bf5b9a4c08153ed07c16c130ff174835c9a1e6ee2b311be758c8b3ef"}, - {file = "tokenizers-0.14.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", 
hash = "sha256:ebefbc26ccff5e96ae7d40772172e7310174f9aa3683d2870a1882313ec3a4d5"}, - {file = "tokenizers-0.14.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d3a6330c9f1deda22873e8b4ac849cc06d3ff33d60b3217ac0bb397b541e1509"}, - {file = "tokenizers-0.14.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6cba7483ba45600346a35c466bde32327b108575022f73c35a0f7170b5a71ae2"}, - {file = "tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60fec380778d75cbb492f14ca974f11f37b41d53c057b9c8ba213315b86e1f84"}, - {file = "tokenizers-0.14.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:930c19b699dd7e1077eac98967adc2fe5f0b104bd96cc1f26778ab82b31ceb24"}, - {file = "tokenizers-0.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a1e30a13376db5329570e09b14c8eb36c017909ed7e88591ca3aa81f3c7d6f32"}, - {file = "tokenizers-0.14.1-cp310-none-win32.whl", hash = "sha256:370b5b86da9bddbe65fa08711f0e8ffdf8b0036558178d1a31dfcb44efcde72a"}, - {file = "tokenizers-0.14.1-cp310-none-win_amd64.whl", hash = "sha256:c2c659f2106b6d154f118ad1b700e68148c46c59b720f04867b1fc5f26a85060"}, - {file = "tokenizers-0.14.1-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:00df4c5bf25c153b432b98689609b426ae701a44f3d8074dcb619f410bc2a870"}, - {file = "tokenizers-0.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fee553657dcdb7e73df8823c49e8611457ba46e9d7026b7e9c44820c08c327c3"}, - {file = "tokenizers-0.14.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a480bd902e327dfcaa52b7dd14fdc71e7aa45d73a3d6e41e028a75891d2823cf"}, - {file = "tokenizers-0.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e448b2be0430ab839cf7954715c39d6f34ff6cf2b49393f336283b7a59f485af"}, - {file = "tokenizers-0.14.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c11444984aecd342f0cf160c3320288edeb1763871fbb560ed466654b2a7016c"}, - 
{file = "tokenizers-0.14.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfe164a1c72c6be3c5c26753c6c412f81412f4dae0d7d06371e0b396a9cc0fc9"}, - {file = "tokenizers-0.14.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:72d9967fb1f927542cfb5347207fde01b29f25c9bb8cbc7ced280decfa015983"}, - {file = "tokenizers-0.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37cc955c84ec67c2d11183d372044399342b20a1fa447b7a33040f4889bba318"}, - {file = "tokenizers-0.14.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:db96cf092d86d4cb543daa9148e299011e0a40770380bb78333b9fd700586fcb"}, - {file = "tokenizers-0.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c84d3cb1349936c2b96ca6175b50f5a9518170bffd76464219ee0ea6022a64a7"}, - {file = "tokenizers-0.14.1-cp311-none-win32.whl", hash = "sha256:8db3a6f3d430ac3dc3793c53fa8e5e665c23ba359484d365a191027ad8b65a30"}, - {file = "tokenizers-0.14.1-cp311-none-win_amd64.whl", hash = "sha256:c65d76052561c60e17cb4fa289885ed00a9995d59e97019fac2138bd45142057"}, - {file = "tokenizers-0.14.1-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:c375161b588982be381c43eb7158c250f430793d0f708ce379a0f196164c6778"}, - {file = "tokenizers-0.14.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50f03d2330a153a9114c2429061137bd323736059f384de8348d7cb1ca1baa15"}, - {file = "tokenizers-0.14.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0c8ee283b249c3c3c201c41bc23adc3be2514ae4121eacdb5c5250a461eaa8c6"}, - {file = "tokenizers-0.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e9f27399b8d50c5d3f08f0aae961bcc66a1dead1cd0ae9401e4c2a43a623322a"}, - {file = "tokenizers-0.14.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:89cbeec7e9d5d8773ec4779c64e3cbcbff53d234ca6ad7b1a3736588003bba48"}, - {file = "tokenizers-0.14.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", 
hash = "sha256:08e55920b453c30b46d58accc68a38e8e7488d0c03babfdb29c55d3f39dd2052"}, - {file = "tokenizers-0.14.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:91d32bd1056c0e83a0f90e4ffa213c25096b2d8b9f0e2d172a45f138c7d8c081"}, - {file = "tokenizers-0.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44f1748035c36c939848c935715bde41734d9249ab7b844ff9bfbe984be8952c"}, - {file = "tokenizers-0.14.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1ff516d129f01bb7a4aa95bc6aae88e4d86dd63bfc2d57db9302c2624d1be7cb"}, - {file = "tokenizers-0.14.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:acfc8db61c6e919d932448cc7985b85e330c8d745528e12fce6e62d40d268bce"}, - {file = "tokenizers-0.14.1-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:ba336bc9107acbc1da2ad30967df7b2db93448ca66538ad86aa1fbb91116f631"}, - {file = "tokenizers-0.14.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:f77371b5030e53f8bf92197640af437539e3bba1bc8342b97888c8e26567bfdc"}, - {file = "tokenizers-0.14.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d72d25c57a9c814240802d188ff0a808b701e2dd2bf1c64721c7088ceeeb1ed7"}, - {file = "tokenizers-0.14.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caf0df8657277e32671aa8a4d3cc05f2050ab19d9b49447f2265304168e9032c"}, - {file = "tokenizers-0.14.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cb3c6bc6e599e46a26ad559ad5dec260ffdf705663cc9b894033d64a69314e86"}, - {file = "tokenizers-0.14.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8cf2fcdc2368df4317e05571e33810eeed24cd594acc9dfc9788b21dac6b3a8"}, - {file = "tokenizers-0.14.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f475d5eda41d2ed51ca775a07c80529a923dd759fcff7abf03ccdd83d9f7564e"}, - {file = "tokenizers-0.14.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:cce4d1a97a7eb2253b5d3f29f4a478d8c37ba0303ea34024eb9e65506d4209f8"}, - {file = "tokenizers-0.14.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ff66577ae55114f7d0f6aa0d4d335f27cae96bf245962a745b718ec887bbe7eb"}, - {file = "tokenizers-0.14.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a687099e085f5162e5b88b3402adb6c2b41046180c015c5075c9504440b6e971"}, - {file = "tokenizers-0.14.1-cp37-none-win32.whl", hash = "sha256:49f5336b82e315a33bef1025d247ca08d95719715b29e33f0e9e8cf15ff1dfb6"}, - {file = "tokenizers-0.14.1-cp37-none-win_amd64.whl", hash = "sha256:117c8da60d1bd95a6df2692926f36de7971baa1d89ff702fae47b6689a4465ad"}, - {file = "tokenizers-0.14.1-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:01d2bd5935642de22a6c6778bb2307f9949cd6eaeeb5c77f9b98f0060b69f0db"}, - {file = "tokenizers-0.14.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b05ec04132394c20bd6bcb692d557a8eb8ab1bac1646d28e49c67c00907d17c8"}, - {file = "tokenizers-0.14.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:7d9025b185465d9d18679406f6f394850347d5ed2681efc203539d800f36f459"}, - {file = "tokenizers-0.14.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2539831838ab5393f78a893d7bbf27d5c36e43baf77e91dc9992922b2b97e09d"}, - {file = "tokenizers-0.14.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec8f46d533092d8e20bc742c47918cbe24b8641dbfbbcb83177c5de3c9d4decb"}, - {file = "tokenizers-0.14.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8b019c4810903fdea3b230f358b9d27377c0f38454778b607676c9e1b57d14b7"}, - {file = "tokenizers-0.14.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e8984114fd83ed3913d89526c992395920930c9620a2feee61faf035f41d7b9a"}, - {file = "tokenizers-0.14.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11284b32f0036fe7ef4b8b00201dda79c00f3fcea173bc0e5c599e09c937ab0f"}, - {file = 
"tokenizers-0.14.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:53614f44f36917282a583180e402105bc63d61d1aca067d51cb7f051eb489901"}, - {file = "tokenizers-0.14.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e3b6082e9532309727273443c8943bb9558d52e36788b246aa278bda7c642116"}, - {file = "tokenizers-0.14.1-cp38-none-win32.whl", hash = "sha256:7560fca3e17a6bc876d20cd825d7721c101fa2b1cd0bfa0abf9a2e781e49b37b"}, - {file = "tokenizers-0.14.1-cp38-none-win_amd64.whl", hash = "sha256:c318a5acb429ca38f632577754235140bbb8c5a27faca1c51b43fbf575596e34"}, - {file = "tokenizers-0.14.1-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:b886e0f5c72aa4249c609c24b9610a9ca83fd963cbb5066b19302723ea505279"}, - {file = "tokenizers-0.14.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f522f28c88a0d5b2f9e895cf405dd594cd518e99d61905406aec74d30eb6383b"}, - {file = "tokenizers-0.14.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5bef76c4d9329913cef2fe79ce1f4dab98f77fa4887e5f0420ffc9386941de32"}, - {file = "tokenizers-0.14.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59c7df2103052b30b7c76d4fa8251326c9f82689578a912698a127dc1737f43e"}, - {file = "tokenizers-0.14.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:232445e7b85255ccfe68dfd42185db8a3f3349b34ad7068404856c4a5f67c355"}, - {file = "tokenizers-0.14.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8e63781da85aa8948864970e529af10abc4084a990d30850c41bbdb5f83eee45"}, - {file = "tokenizers-0.14.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5760a831c0f3c6d3229b50ef3fafa4c164ec99d7e8c2237fe144e67a9d33b120"}, - {file = "tokenizers-0.14.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c84b456ff8525ec3ff09762e32ccc27888d036dcd0ba2883e1db491e164dd725"}, - {file = "tokenizers-0.14.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = 
"sha256:463ee5f3afbfec29cbf5652752c9d1032bdad63daf48bb8cb9970064cc81d5f9"}, - {file = "tokenizers-0.14.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ee6b63aecf929a7bcf885bdc8a8aec96c43bc4442f63fe8c6d48f24fc992b05b"}, - {file = "tokenizers-0.14.1-cp39-none-win32.whl", hash = "sha256:aae42798ba1da3bc1572b2048fe42e61dd6bacced2b424cb0f5572c5432f79c2"}, - {file = "tokenizers-0.14.1-cp39-none-win_amd64.whl", hash = "sha256:68c4699147dded6926a3d2c2f948d435d54d027f69909e0ef3c6587933723ed2"}, - {file = "tokenizers-0.14.1-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:5f9afdcf701a1aa3c41e0e748c152d2162434d61639a1e5d8523ecf60ae35aea"}, - {file = "tokenizers-0.14.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:6859d81243cd09854be9054aca3ecab14a2dee5b3c9f6d7ef12061d478ca0c57"}, - {file = "tokenizers-0.14.1-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:7975178f9478ccedcf613332d5d6f37b67c74ef4e2e47e0c965597506b921f04"}, - {file = "tokenizers-0.14.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ce2f0ff2e5f12ac5bebaa690606395725239265d7ffa35f35c243a379316297"}, - {file = "tokenizers-0.14.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c7cfc3d42e81cda802f93aa9e92caf79feaa1711426e28ce620560b8aaf5e4d"}, - {file = "tokenizers-0.14.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:67d3adff654dc7f7c7091dd259b3b847fe119c08d0bda61db91e2ea2b61c38c0"}, - {file = "tokenizers-0.14.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:956729b7dd599020e57133fb95b777e4f81ee069ff0a70e80f6eeac82658972f"}, - {file = "tokenizers-0.14.1-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:fe2ea1177146a7ab345ab61e90a490eeea25d5f063e1cb9d4eb1425b169b64d7"}, - {file = "tokenizers-0.14.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9930f31f603ecc6ea54d5c6dfa299f926ab3e921f72f94babcb02598c32b57c6"}, - {file 
= "tokenizers-0.14.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d49567a2754e9991c05c2b5a7e6650b56e24365b7cab504558e58033dcf0edc4"}, - {file = "tokenizers-0.14.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3678be5db330726f19c1949d8ae1b845a02eeb2a2e1d5a8bb8eaa82087ae25c1"}, - {file = "tokenizers-0.14.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:42b180ed1bec58ab9bdc65d406577e0c0fb7241b74b8c032846073c7743c9f86"}, - {file = "tokenizers-0.14.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:319e4367596fb0d52be645b3de1616faf0fadaf28507ce1c7595bebd9b4c402c"}, - {file = "tokenizers-0.14.1-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:2cda65b689aec63b7c76a77f43a08044fa90bbc6ad9849267cedfee9795913f3"}, - {file = "tokenizers-0.14.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:ca0bfc79b27d84fcb7fa09339b2ee39077896738d9a30ff99c0332376e985072"}, - {file = "tokenizers-0.14.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a7093767e070269e22e2c5f845e46510304f124c32d2cd249633c0f27eb29d86"}, - {file = "tokenizers-0.14.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad759ba39cd32c2c2247864d02c84ea5883b5f6cc6a4ee0c95602a3dde52268f"}, - {file = "tokenizers-0.14.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26fee36a6d8f2bd9464f3566b95e3e3fb7fd7dad723f775c500aac8204ec98c6"}, - {file = "tokenizers-0.14.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d091c62cb7abbd32e527a85c41f7c8eb4526a926251891fc4ecbe5f974142ffb"}, - {file = "tokenizers-0.14.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ca304402ea66d58f99c05aa3d7a6052faea61e5a8313b94f6bc36fbf27960e2d"}, - {file = "tokenizers-0.14.1-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:102f118fa9b720b93c3217c1e239ed7bc1ae1e8dbfe9b4983a4f2d7b4ce6f2ec"}, - {file = 
"tokenizers-0.14.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:df4f058e96e8b467b7742e5dba7564255cd482d3c1e6cf81f8cb683bb0433340"}, - {file = "tokenizers-0.14.1-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:040ee44efc1806900de72b13c1c3036154077d9cde189c9a7e7a50bbbdcbf39f"}, - {file = "tokenizers-0.14.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7618b84118ae704f7fa23c4a190bd80fc605671841a4427d5ca14b9b8d9ec1a3"}, - {file = "tokenizers-0.14.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ecdfe9736c4a73343f629586016a137a10faed1a29c6dc699d8ab20c2d3cf64"}, - {file = "tokenizers-0.14.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:92c34de04fec7f4ff95f7667d4eb085c4e4db46c31ef44c3d35c38df128430da"}, - {file = "tokenizers-0.14.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:628b654ba555b2ba9111c0936d558b14bfc9d5f57b8c323b02fc846036b38b2f"}, - {file = "tokenizers-0.14.1.tar.gz", hash = "sha256:ea3b3f8908a9a5b9d6fc632b5f012ece7240031c44c6d4764809f33736534166"}, + {file = "tokenizers-0.15.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:cd3cd0299aaa312cd2988957598f80becd04d5a07338741eca076057a2b37d6e"}, + {file = "tokenizers-0.15.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a922c492c721744ee175f15b91704be2d305569d25f0547c77cd6c9f210f9dc"}, + {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:331dd786d02fc38698f835fff61c99480f98b73ce75a4c65bd110c9af5e4609a"}, + {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88dd0961c437d413ab027f8b115350c121d49902cfbadf08bb8f634b15fa1814"}, + {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6fdcc55339df7761cd52e1fbe8185d3b3963bc9e3f3545faa6c84f9e8818259a"}, + {file = 
"tokenizers-0.15.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1480b0051d8ab5408e8e4db2dc832f7082ea24aa0722c427bde2418c6f3bd07"}, + {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9855e6c258918f9cf62792d4f6ddfa6c56dccd8c8118640f867f6393ecaf8bd7"}, + {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de9529fe75efcd54ba8d516aa725e1851df9199f0669b665c55e90df08f5af86"}, + {file = "tokenizers-0.15.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:8edcc90a36eab0705fe9121d6c77c6e42eeef25c7399864fd57dfb27173060bf"}, + {file = "tokenizers-0.15.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ae17884aafb3e94f34fb7cfedc29054f5f54e142475ebf8a265a4e388fee3f8b"}, + {file = "tokenizers-0.15.0-cp310-none-win32.whl", hash = "sha256:9a3241acdc9b44cff6e95c4a55b9be943ef3658f8edb3686034d353734adba05"}, + {file = "tokenizers-0.15.0-cp310-none-win_amd64.whl", hash = "sha256:4b31807cb393d6ea31926b307911c89a1209d5e27629aa79553d1599c8ffdefe"}, + {file = "tokenizers-0.15.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:af7e9be8c05d30bb137b9fd20f9d99354816599e5fd3d58a4b1e28ba3b36171f"}, + {file = "tokenizers-0.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c3d7343fa562ea29661783344a2d83662db0d3d17a6fa6a403cac8e512d2d9fd"}, + {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:32371008788aeeb0309a9244809a23e4c0259625e6b74a103700f6421373f395"}, + {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca9db64c7c9954fbae698884c5bb089764edc549731e5f9b7fa1dd4e4d78d77f"}, + {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dbed5944c31195514669cf6381a0d8d47f164943000d10f93d6d02f0d45c25e0"}, + {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:aab16c4a26d351d63e965b0c792f5da7227a37b69a6dc6d922ff70aa595b1b0c"}, + {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3c2b60b12fdd310bf85ce5d7d3f823456b9b65eed30f5438dd7761879c495983"}, + {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0344d6602740e44054a9e5bbe9775a5e149c4dddaff15959bb07dcce95a5a859"}, + {file = "tokenizers-0.15.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4525f6997d81d9b6d9140088f4f5131f6627e4c960c2c87d0695ae7304233fc3"}, + {file = "tokenizers-0.15.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:65975094fef8cc68919644936764efd2ce98cf1bacbe8db2687155d2b0625bee"}, + {file = "tokenizers-0.15.0-cp311-none-win32.whl", hash = "sha256:ff5d2159c5d93015f5a4542aac6c315506df31853123aa39042672031768c301"}, + {file = "tokenizers-0.15.0-cp311-none-win_amd64.whl", hash = "sha256:2dd681b53cf615e60a31a115a3fda3980e543d25ca183797f797a6c3600788a3"}, + {file = "tokenizers-0.15.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:c9cce6ee149a3d703f86877bc2a6d997e34874b2d5a2d7839e36b2273f31d3d9"}, + {file = "tokenizers-0.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4a0a94bc3370e6f1cc8a07a8ae867ce13b7c1b4291432a773931a61f256d44ea"}, + {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:309cfcccfc7e502cb1f1de2c9c1c94680082a65bfd3a912d5a5b2c90c677eb60"}, + {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8413e994dd7d875ab13009127fc85633916c71213917daf64962bafd488f15dc"}, + {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d0ebf9430f901dbdc3dcb06b493ff24a3644c9f88c08e6a1d6d0ae2228b9b818"}, + {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:10361e9c7864b22dd791ec5126327f6c9292fb1d23481d4895780688d5e298ac"}, + {file = 
"tokenizers-0.15.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:babe42635b8a604c594bdc56d205755f73414fce17ba8479d142a963a6c25cbc"}, + {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3768829861e964c7a4556f5f23307fce6a23872c2ebf030eb9822dbbbf7e9b2a"}, + {file = "tokenizers-0.15.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9c91588a630adc88065e1c03ac6831e3e2112558869b9ebcb2b8afd8a14c944d"}, + {file = "tokenizers-0.15.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:77606994e793ca54ecf3a3619adc8a906a28ca223d9354b38df41cb8766a0ed6"}, + {file = "tokenizers-0.15.0-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:6fe143939f3b596681922b2df12a591a5b010e7dcfbee2202482cd0c1c2f2459"}, + {file = "tokenizers-0.15.0-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:b7bee0f1795e3e3561e9a557061b1539e5255b8221e3f928f58100282407e090"}, + {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5d37e7f4439b4c46192ab4f2ff38ab815e4420f153caa13dec9272ef14403d34"}, + {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caadf255cf7f951b38d10097836d1f3bcff4aeaaffadfdf748bab780bf5bff95"}, + {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:05accb9162bf711a941b1460b743d62fec61c160daf25e53c5eea52c74d77814"}, + {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26a2ef890740127cb115ee5260878f4a677e36a12831795fd7e85887c53b430b"}, + {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e54c5f26df14913620046b33e822cb3bcd091a332a55230c0e63cc77135e2169"}, + {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:669b8ed653a578bcff919566631156f5da3aab84c66f3c0b11a6281e8b4731c7"}, + {file = 
"tokenizers-0.15.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0ea480d943297df26f06f508dab6e012b07f42bf3dffdd36e70799368a5f5229"}, + {file = "tokenizers-0.15.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:bc80a0a565ebfc7cd89de7dd581da8c2b3238addfca6280572d27d763f135f2f"}, + {file = "tokenizers-0.15.0-cp37-none-win32.whl", hash = "sha256:cdd945e678bbdf4517d5d8de66578a5030aeefecdb46f5320b034de9cad8d4dd"}, + {file = "tokenizers-0.15.0-cp37-none-win_amd64.whl", hash = "sha256:1ab96ab7dc706e002c32b2ea211a94c1c04b4f4de48354728c3a6e22401af322"}, + {file = "tokenizers-0.15.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:f21c9eb71c9a671e2a42f18b456a3d118e50c7f0fc4dd9fa8f4eb727fea529bf"}, + {file = "tokenizers-0.15.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2a5f4543a35889679fc3052086e69e81880b2a5a28ff2a52c5a604be94b77a3f"}, + {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f8aa81afec893e952bd39692b2d9ef60575ed8c86fce1fd876a06d2e73e82dca"}, + {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1574a5a4af22c3def93fe8fe4adcc90a39bf5797ed01686a4c46d1c3bc677d2f"}, + {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7c7982fd0ec9e9122d03b209dac48cebfea3de0479335100ef379a9a959b9a5a"}, + {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8d16b647032df2ce2c1f9097236e046ea9fedd969b25637b9d5d734d78aa53b"}, + {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b3cdf29e6f9653da330515dc8fa414be5a93aae79e57f8acc50d4028dd843edf"}, + {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7286f3df10de840867372e3e64b99ef58c677210e3ceb653cd0e740a5c53fe78"}, + {file = "tokenizers-0.15.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = 
"sha256:aabc83028baa5a36ce7a94e7659250f0309c47fa4a639e5c2c38e6d5ea0de564"}, + {file = "tokenizers-0.15.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:72f78b0e0e276b1fc14a672fa73f3acca034ba8db4e782124a2996734a9ba9cf"}, + {file = "tokenizers-0.15.0-cp38-none-win32.whl", hash = "sha256:9680b0ecc26e7e42f16680c1aa62e924d58d1c2dd992707081cc10a374896ea2"}, + {file = "tokenizers-0.15.0-cp38-none-win_amd64.whl", hash = "sha256:f17cbd88dab695911cbdd385a5a7e3709cc61dff982351f5d1b5939f074a2466"}, + {file = "tokenizers-0.15.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:3661862df7382c5eb23ac4fbf7c75e69b02dc4f5784e4c5a734db406b5b24596"}, + {file = "tokenizers-0.15.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c3045d191dad49647f5a5039738ecf1c77087945c7a295f7bcf051c37067e883"}, + {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a9fcaad9ab0801f14457d7c820d9f246b5ab590c407fc6b073819b1573097aa7"}, + {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a79f17027f24fe9485701c8dbb269b9c713954ec3bdc1e7075a66086c0c0cd3c"}, + {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:01a3aa332abc4bee7640563949fcfedca4de8f52691b3b70f2fc6ca71bfc0f4e"}, + {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05b83896a893cdfedad8785250daa3ba9f0504848323471524d4783d7291661e"}, + {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cbbf2489fcf25d809731ba2744ff278dd07d9eb3f8b7482726bd6cae607073a4"}, + {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab806ad521a5e9de38078b7add97589c313915f6f5fec6b2f9f289d14d607bd6"}, + {file = "tokenizers-0.15.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4a522612d5c88a41563e3463226af64e2fa00629f65cdcc501d1995dd25d23f5"}, + {file = 
"tokenizers-0.15.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e58a38c4e6075810bdfb861d9c005236a72a152ebc7005941cc90d1bbf16aca9"}, + {file = "tokenizers-0.15.0-cp39-none-win32.whl", hash = "sha256:b8034f1041fd2bd2b84ff9f4dc4ae2e1c3b71606820a9cd5c562ebd291a396d1"}, + {file = "tokenizers-0.15.0-cp39-none-win_amd64.whl", hash = "sha256:edde9aa964145d528d0e0dbf14f244b8a85ebf276fb76869bc02e2530fa37a96"}, + {file = "tokenizers-0.15.0-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:309445d10d442b7521b98083dc9f0b5df14eca69dbbfebeb98d781ee2cef5d30"}, + {file = "tokenizers-0.15.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d3125a6499226d4d48efc54f7498886b94c418e93a205b673bc59364eecf0804"}, + {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ed56ddf0d54877bb9c6d885177db79b41576e61b5ef6defeb579dcb803c04ad5"}, + {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b22cd714706cc5b18992a232b023f736e539495f5cc61d2d28d176e55046f6c"}, + {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fac2719b1e9bc8e8e7f6599b99d0a8e24f33d023eb8ef644c0366a596f0aa926"}, + {file = "tokenizers-0.15.0-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:85ddae17570ec7e5bfaf51ffa78d044f444a8693e1316e1087ee6150596897ee"}, + {file = "tokenizers-0.15.0-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:76f1bed992e396bf6f83e3df97b64ff47885e45e8365f8983afed8556a0bc51f"}, + {file = "tokenizers-0.15.0-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:3bb0f4df6dce41a1c7482087b60d18c372ef4463cb99aa8195100fcd41e0fd64"}, + {file = "tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:22c27672c27a059a5f39ff4e49feed8c7f2e1525577c8a7e3978bd428eb5869d"}, + {file = 
"tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78104f5d035c9991f92831fc0efe9e64a05d4032194f2a69f67aaa05a4d75bbb"}, + {file = "tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a40b73dc19d82c3e3ffb40abdaacca8fbc95eeb26c66b7f9f860aebc07a73998"}, + {file = "tokenizers-0.15.0-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d801d1368188c74552cd779b1286e67cb9fd96f4c57a9f9a2a09b6def9e1ab37"}, + {file = "tokenizers-0.15.0-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:82641ffb13a4da1293fcc9f437d457647e60ed0385a9216cd135953778b3f0a1"}, + {file = "tokenizers-0.15.0-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:160f9d1810f2c18fffa94aa98bf17632f6bd2dabc67fcb01a698ca80c37d52ee"}, + {file = "tokenizers-0.15.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:8d7d6eea831ed435fdeeb9bcd26476226401d7309d115a710c65da4088841948"}, + {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f6456bec6c557d63d8ec0023758c32f589e1889ed03c055702e84ce275488bed"}, + {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1eef39a502fad3bf104b9e1906b4fb0cee20e44e755e51df9a98f8922c3bf6d4"}, + {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1e4664c5b797e093c19b794bbecc19d2367e782b4a577d8b7c1821db5dc150d"}, + {file = "tokenizers-0.15.0-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ca003fb5f3995ff5cf676db6681b8ea5d54d3b30bea36af1120e78ee1a4a4cdf"}, + {file = "tokenizers-0.15.0-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7f17363141eb0c53752c89e10650b85ef059a52765d0802ba9613dbd2d21d425"}, + {file = "tokenizers-0.15.0-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:8a765db05581c7d7e1280170f2888cda351760d196cc059c37ea96f121125799"}, + {file = 
"tokenizers-0.15.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:2a0dd641a72604486cd7302dd8f87a12c8a9b45e1755e47d2682733f097c1af5"}, + {file = "tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0a1a3c973e4dc97797fc19e9f11546c95278ffc55c4492acb742f69e035490bc"}, + {file = "tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4fab75642aae4e604e729d6f78e0addb9d7e7d49e28c8f4d16b24da278e5263"}, + {file = "tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65f80be77f6327a86d8fd35a4467adcfe6174c159b4ab52a1a8dd4c6f2d7d9e1"}, + {file = "tokenizers-0.15.0-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:a8da7533dbe66b88afd430c56a2f2ce1fd82e2681868f857da38eeb3191d7498"}, + {file = "tokenizers-0.15.0-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fa8eb4584fc6cbe6a84d7a7864be3ed28e23e9fd2146aa8ef1814d579df91958"}, + {file = "tokenizers-0.15.0.tar.gz", hash = "sha256:10c7e6e7b4cabd757da59e93f5f8d1126291d16f8b54f28510825ef56a3e5d0e"}, ] [package.dependencies] -huggingface_hub = ">=0.16.4,<0.18" +huggingface_hub = ">=0.16.4,<1.0" [package.extras] dev = ["tokenizers[testing]"] @@ -4432,13 +4442,13 @@ files = [ [[package]] name = "tomlkit" -version = "0.12.1" +version = "0.12.3" description = "Style preserving TOML library" optional = false python-versions = ">=3.7" files = [ - {file = "tomlkit-0.12.1-py3-none-any.whl", hash = "sha256:712cbd236609acc6a3e2e97253dfc52d4c2082982a88f61b640ecf0817eab899"}, - {file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"}, + {file = "tomlkit-0.12.3-py3-none-any.whl", hash = "sha256:b0a645a9156dc7cb5d3a1f0d4bab66db287fcb8e0430bdd4664a095ea16414ba"}, + {file = "tomlkit-0.12.3.tar.gz", hash = "sha256:75baf5012d06501f07bee5bf8e801b9f343e7aac5a92581f20f80ce632e6b5a4"}, ] [[package]] @@ -4496,22 +4506,22 @@ opt-einsum = 
["opt-einsum (>=3.3)"] [[package]] name = "tornado" -version = "6.3.3" +version = "6.4" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." optional = false python-versions = ">= 3.8" files = [ - {file = "tornado-6.3.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:502fba735c84450974fec147340016ad928d29f1e91f49be168c0a4c18181e1d"}, - {file = "tornado-6.3.3-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:805d507b1f588320c26f7f097108eb4023bbaa984d63176d1652e184ba24270a"}, - {file = "tornado-6.3.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1bd19ca6c16882e4d37368e0152f99c099bad93e0950ce55e71daed74045908f"}, - {file = "tornado-6.3.3-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ac51f42808cca9b3613f51ffe2a965c8525cb1b00b7b2d56828b8045354f76a"}, - {file = "tornado-6.3.3-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71a8db65160a3c55d61839b7302a9a400074c9c753040455494e2af74e2501f2"}, - {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:ceb917a50cd35882b57600709dd5421a418c29ddc852da8bcdab1f0db33406b0"}, - {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:7d01abc57ea0dbb51ddfed477dfe22719d376119844e33c661d873bf9c0e4a16"}, - {file = "tornado-6.3.3-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:9dc4444c0defcd3929d5c1eb5706cbe1b116e762ff3e0deca8b715d14bf6ec17"}, - {file = "tornado-6.3.3-cp38-abi3-win32.whl", hash = "sha256:65ceca9500383fbdf33a98c0087cb975b2ef3bfb874cb35b8de8740cf7f41bd3"}, - {file = "tornado-6.3.3-cp38-abi3-win_amd64.whl", hash = "sha256:22d3c2fa10b5793da13c807e6fc38ff49a4f6e1e3868b0a6f4164768bb8e20f5"}, - {file = "tornado-6.3.3.tar.gz", hash = "sha256:e7d8db41c0181c80d76c982aacc442c0783a2c54d6400fe028954201a2e032fe"}, + {file = 
"tornado-6.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:02ccefc7d8211e5a7f9e8bc3f9e5b0ad6262ba2fbb683a6443ecc804e5224ce0"}, + {file = "tornado-6.4-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:27787de946a9cffd63ce5814c33f734c627a87072ec7eed71f7fc4417bb16263"}, + {file = "tornado-6.4-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7894c581ecdcf91666a0912f18ce5e757213999e183ebfc2c3fdbf4d5bd764e"}, + {file = "tornado-6.4-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e43bc2e5370a6a8e413e1e1cd0c91bedc5bd62a74a532371042a18ef19e10579"}, + {file = "tornado-6.4-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0251554cdd50b4b44362f73ad5ba7126fc5b2c2895cc62b14a1c2d7ea32f212"}, + {file = "tornado-6.4-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:fd03192e287fbd0899dd8f81c6fb9cbbc69194d2074b38f384cb6fa72b80e9c2"}, + {file = "tornado-6.4-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:88b84956273fbd73420e6d4b8d5ccbe913c65d31351b4c004ae362eba06e1f78"}, + {file = "tornado-6.4-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:71ddfc23a0e03ef2df1c1397d859868d158c8276a0603b96cf86892bff58149f"}, + {file = "tornado-6.4-cp38-abi3-win32.whl", hash = "sha256:6f8a6c77900f5ae93d8b4ae1196472d0ccc2775cc1dfdc9e7727889145c45052"}, + {file = "tornado-6.4-cp38-abi3-win_amd64.whl", hash = "sha256:10aeaa8006333433da48dec9fe417877f8bcc21f48dda8d661ae79da357b2a63"}, + {file = "tornado-6.4.tar.gz", hash = "sha256:72291fa6e6bc84e626589f1c29d90a5a6d593ef5ae68052ee2ef000dfd273dee"}, ] [[package]] @@ -4536,28 +4546,28 @@ telegram = ["requests"] [[package]] name = "traitlets" -version = "5.12.0" +version = "5.14.0" description = "Traitlets Python configuration system" optional = false python-versions = ">=3.8" files = [ - {file = "traitlets-5.12.0-py3-none-any.whl", hash = 
"sha256:81539f07f7aebcde2e4b5ab76727f53eabf18ad155c6ed7979a681411602fa47"}, - {file = "traitlets-5.12.0.tar.gz", hash = "sha256:833273bf645d8ce31dcb613c56999e2e055b1ffe6d09168a164bcd91c36d5d35"}, + {file = "traitlets-5.14.0-py3-none-any.whl", hash = "sha256:f14949d23829023013c47df20b4a76ccd1a85effb786dc060f34de7948361b33"}, + {file = "traitlets-5.14.0.tar.gz", hash = "sha256:fcdaa8ac49c04dfa0ed3ee3384ef6dfdb5d6f3741502be247279407679296772"}, ] [package.extras] docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] -test = ["argcomplete (>=3.0.3)", "mypy (>=1.6.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] +test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] [[package]] name = "transformers" -version = "4.34.1" +version = "4.35.2" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.34.1-py3-none-any.whl", hash = "sha256:d06ac09151d7b845e4a4acd6b143a591d946031ee67b4cbb20693b241920ffc0"}, - {file = "transformers-4.34.1.tar.gz", hash = "sha256:1d0258d5a18063b66005bbe1e3276ec5943d9ab4ab47f020db1fd485cc40ea22"}, + {file = "transformers-4.35.2-py3-none-any.whl", hash = "sha256:9dfa76f8692379544ead84d98f537be01cd1070de75c74efb13abcbc938fbe2f"}, + {file = "transformers-4.35.2.tar.gz", hash = "sha256:2d125e197d77b0cdb6c9201df9fa7e2101493272e448b9fba9341c695bee2f52"}, ] [package.dependencies] @@ -4569,23 +4579,22 @@ pyyaml = ">=5.1" regex = "!=2019.12.17" requests = "*" safetensors = ">=0.3.1" -tokenizers = ">=0.14,<0.15" +tokenizers = ">=0.14,<0.19" tqdm = ">=4.27" [package.extras] accelerate = ["accelerate (>=0.20.3)"] agents = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.10,!=1.12.0)"] -all = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", 
"codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +all = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] codecarbon = ["codecarbon (==1.2.0)"] deepspeed = ["accelerate (>=0.20.3)", "deepspeed (>=0.9.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill 
(<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.15)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", 
"codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -docs = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", 
"sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", 
"scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] docs-specific = ["hf-doc-builder"] -fairscale = ["fairscale (>0.3)"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib 
(>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] ftfy = ["ftfy"] @@ -4605,16 +4614,16 @@ serving = ["fastapi", "pydantic (<2)", "starlette", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "timeout-decorator"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx"] tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] timm = ["timm"] -tokenizers = ["tokenizers (>=0.14,<0.15)"] +tokenizers = ["tokenizers (>=0.14,<0.19)"] torch = ["accelerate (>=0.20.3)", "torch (>=1.10,!=1.12.0)"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] torch-vision = ["Pillow (<10.0.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.16.4,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex 
(!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"] +torchhub = ["filelock", "huggingface-hub (>=0.16.4,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"] video = ["av (==9.2.0)", "decord (==0.6.0)"] vision = ["Pillow (<10.0.0)"] @@ -4747,30 +4756,29 @@ dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake [[package]] name = "urllib3" -version = "2.0.7" +version = "2.1.0" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "urllib3-2.0.7-py3-none-any.whl", hash = "sha256:fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e"}, - {file = "urllib3-2.0.7.tar.gz", hash = "sha256:c97dfde1f7bd43a71c8d2a58e369e9b2bf692d1334ea9f9cae55add7d0dd0f84"}, + {file = "urllib3-2.1.0-py3-none-any.whl", hash = "sha256:55901e917a5896a349ff771be919f8bd99aff50b79fe58fec595eb37bbc56bb3"}, + {file = "urllib3-2.1.0.tar.gz", hash = "sha256:df7aa8afb0148fa78488e7899b2c59b5f4ffcfa82e6c54ccb9dd37c1d7b52d54"}, ] [package.extras] brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] -secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] [[package]] name = "wandb" -version = "0.15.12" +version = "0.16.1" description = "A CLI and library for interacting with the Weights & Biases API." 
optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" files = [ - {file = "wandb-0.15.12-py3-none-any.whl", hash = "sha256:75c57b5bb8ddae21d45a02f644628585bdd112fea686de3177099a0996f1c41c"}, - {file = "wandb-0.15.12.tar.gz", hash = "sha256:c344d92fb8044b072a6138afd9adc5d3801ad050cf11378fe2af2fe899dcca84"}, + {file = "wandb-0.16.1-py3-none-any.whl", hash = "sha256:1d7423f92520984585bae9693bb637ae08d3e0c1d75ad4b34215bc44431f114c"}, + {file = "wandb-0.16.1.tar.gz", hash = "sha256:ffe6e8dd8cc8fcd72010c1246fb3d6d226b37c4f111f3f94308a1c0ae28a2fec"}, ] [package.dependencies] @@ -4778,7 +4786,6 @@ appdirs = ">=1.4.3" Click = ">=7.1,<8.0.0 || >8.0.0" docker-pycreds = ">=0.4.0" GitPython = ">=1.0.0,<3.1.29 || >3.1.29" -pathtools = "*" protobuf = [ {version = ">=3.12.0,<4.21.0 || >4.21.0,<5", markers = "python_version < \"3.9\" and sys_platform == \"linux\""}, {version = ">=3.19.0,<4.21.0 || >4.21.0,<5", markers = "python_version > \"3.9\" or sys_platform != \"linux\""}, @@ -4793,27 +4800,27 @@ setuptools = "*" typing-extensions = {version = "*", markers = "python_version < \"3.10\""} [package.extras] -async = ["httpx (>=0.22.0)"] +async = ["httpx (>=0.23.0)"] aws = ["boto3"] azure = ["azure-identity", "azure-storage-blob"] +core = ["wandb-core (>=0.17.0b2)"] gcp = ["google-cloud-storage"] kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"] -launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "nbconvert", "nbformat", "optuna", "typing-extensions"] +launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", 
"kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "typing-extensions"] media = ["bokeh", "moviepy", "numpy", "pillow", "plotly", "rdkit-pypi", "soundfile"] models = ["cloudpickle"] -nexus = ["wandb-core (>=0.16.0b1)"] perf = ["orjson"] sweeps = ["sweeps (>=0.2.0)"] [[package]] name = "wcwidth" -version = "0.2.8" +version = "0.2.12" description = "Measures the displayed width of unicode strings in a terminal" optional = false python-versions = "*" files = [ - {file = "wcwidth-0.2.8-py2.py3-none-any.whl", hash = "sha256:77f719e01648ed600dfa5402c347481c0992263b81a027344f3e1ba25493a704"}, - {file = "wcwidth-0.2.8.tar.gz", hash = "sha256:8705c569999ffbb4f6a87c6d1b80f324bd6db952f5eb0b95bc07517f4c1813d4"}, + {file = "wcwidth-0.2.12-py2.py3-none-any.whl", hash = "sha256:f26ec43d96c8cbfed76a5075dac87680124fa84e0855195a6184da9c187f133c"}, + {file = "wcwidth-0.2.12.tar.gz", hash = "sha256:f01c104efdf57971bcb756f054dd58ddec5204dd15fa31d6503ea57947d97c02"}, ] [[package]] @@ -4844,13 +4851,13 @@ files = [ [[package]] name = "websocket-client" -version = "1.6.4" +version = "1.7.0" description = "WebSocket client for Python with low level API options" optional = false python-versions = ">=3.8" files = [ - {file = "websocket-client-1.6.4.tar.gz", hash = "sha256:b3324019b3c28572086c4a319f91d1dcd44e6e11cd340232978c684a7650d0df"}, - {file = "websocket_client-1.6.4-py3-none-any.whl", hash = "sha256:084072e0a7f5f347ef2ac3d8698a5e0b4ffbfcab607628cadabc650fc9a83a24"}, + {file = "websocket-client-1.7.0.tar.gz", hash = "sha256:10e511ea3a8c744631d3bd77e61eb17ed09304c413ad42cf6ddfa4c7787e8fe6"}, + {file = "websocket_client-1.7.0-py3-none-any.whl", hash = "sha256:f4c3d22fec12a2461427a29957ff07d35098ee2d976d3ba244e688b8b4057588"}, ] [package.extras] @@ -4988,85 +4995,101 @@ files = [ [[package]] name = "yarl" -version = "1.9.2" +version = "1.9.4" description = "Yet another URL library" optional = false python-versions = ">=3.7" files = [ - {file = 
"yarl-1.9.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c2ad583743d16ddbdf6bb14b5cd76bf43b0d0006e918809d5d4ddf7bde8dd82"}, - {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82aa6264b36c50acfb2424ad5ca537a2060ab6de158a5bd2a72a032cc75b9eb8"}, - {file = "yarl-1.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c0c77533b5ed4bcc38e943178ccae29b9bcf48ffd1063f5821192f23a1bd27b9"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee4afac41415d52d53a9833ebae7e32b344be72835bbb589018c9e938045a560"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9bf345c3a4f5ba7f766430f97f9cc1320786f19584acc7086491f45524a551ac"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a96c19c52ff442a808c105901d0bdfd2e28575b3d5f82e2f5fd67e20dc5f4ea"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:891c0e3ec5ec881541f6c5113d8df0315ce5440e244a716b95f2525b7b9f3608"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3a53ba34a636a256d767c086ceb111358876e1fb6b50dfc4d3f4951d40133d5"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:566185e8ebc0898b11f8026447eacd02e46226716229cea8db37496c8cdd26e0"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2b0738fb871812722a0ac2154be1f049c6223b9f6f22eec352996b69775b36d4"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:32f1d071b3f362c80f1a7d322bfd7b2d11e33d2adf395cc1dd4df36c9c243095"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e9fdc7ac0d42bc3ea78818557fab03af6181e076a2944f43c38684b4b6bed8e3"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = 
"sha256:56ff08ab5df8429901ebdc5d15941b59f6253393cb5da07b4170beefcf1b2528"}, - {file = "yarl-1.9.2-cp310-cp310-win32.whl", hash = "sha256:8ea48e0a2f931064469bdabca50c2f578b565fc446f302a79ba6cc0ee7f384d3"}, - {file = "yarl-1.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:50f33040f3836e912ed16d212f6cc1efb3231a8a60526a407aeb66c1c1956dde"}, - {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:646d663eb2232d7909e6601f1a9107e66f9791f290a1b3dc7057818fe44fc2b6"}, - {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aff634b15beff8902d1f918012fc2a42e0dbae6f469fce134c8a0dc51ca423bb"}, - {file = "yarl-1.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a83503934c6273806aed765035716216cc9ab4e0364f7f066227e1aaea90b8d0"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b25322201585c69abc7b0e89e72790469f7dad90d26754717f3310bfe30331c2"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22a94666751778629f1ec4280b08eb11815783c63f52092a5953faf73be24191"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ec53a0ea2a80c5cd1ab397925f94bff59222aa3cf9c6da938ce05c9ec20428d"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:159d81f22d7a43e6eabc36d7194cb53f2f15f498dbbfa8edc8a3239350f59fe7"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:832b7e711027c114d79dffb92576acd1bd2decc467dec60e1cac96912602d0e6"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:95d2ecefbcf4e744ea952d073c6922e72ee650ffc79028eb1e320e732898d7e8"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d4e2c6d555e77b37288eaf45b8f60f0737c9efa3452c6c44626a5455aeb250b9"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = 
"sha256:783185c75c12a017cc345015ea359cc801c3b29a2966c2655cd12b233bf5a2be"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:b8cc1863402472f16c600e3e93d542b7e7542a540f95c30afd472e8e549fc3f7"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:822b30a0f22e588b32d3120f6d41e4ed021806418b4c9f0bc3048b8c8cb3f92a"}, - {file = "yarl-1.9.2-cp311-cp311-win32.whl", hash = "sha256:a60347f234c2212a9f0361955007fcf4033a75bf600a33c88a0a8e91af77c0e8"}, - {file = "yarl-1.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:be6b3fdec5c62f2a67cb3f8c6dbf56bbf3f61c0f046f84645cd1ca73532ea051"}, - {file = "yarl-1.9.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38a3928ae37558bc1b559f67410df446d1fbfa87318b124bf5032c31e3447b74"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac9bb4c5ce3975aeac288cfcb5061ce60e0d14d92209e780c93954076c7c4367"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3da8a678ca8b96c8606bbb8bfacd99a12ad5dd288bc6f7979baddd62f71c63ef"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13414591ff516e04fcdee8dc051c13fd3db13b673c7a4cb1350e6b2ad9639ad3"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf74d08542c3a9ea97bb8f343d4fcbd4d8f91bba5ec9d5d7f792dbe727f88938"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e7221580dc1db478464cfeef9b03b95c5852cc22894e418562997df0d074ccc"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:494053246b119b041960ddcd20fd76224149cfea8ed8777b687358727911dd33"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:52a25809fcbecfc63ac9ba0c0fb586f90837f5425edfd1ec9f3372b119585e45"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = 
"sha256:e65610c5792870d45d7b68c677681376fcf9cc1c289f23e8e8b39c1485384185"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:1b1bba902cba32cdec51fca038fd53f8beee88b77efc373968d1ed021024cc04"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:662e6016409828ee910f5d9602a2729a8a57d74b163c89a837de3fea050c7582"}, - {file = "yarl-1.9.2-cp37-cp37m-win32.whl", hash = "sha256:f364d3480bffd3aa566e886587eaca7c8c04d74f6e8933f3f2c996b7f09bee1b"}, - {file = "yarl-1.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6a5883464143ab3ae9ba68daae8e7c5c95b969462bbe42e2464d60e7e2698368"}, - {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5610f80cf43b6202e2c33ba3ec2ee0a2884f8f423c8f4f62906731d876ef4fac"}, - {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9a4e67ad7b646cd6f0938c7ebfd60e481b7410f574c560e455e938d2da8e0f4"}, - {file = "yarl-1.9.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:83fcc480d7549ccebe9415d96d9263e2d4226798c37ebd18c930fce43dfb9574"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fcd436ea16fee7d4207c045b1e340020e58a2597301cfbcfdbe5abd2356c2fb"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84e0b1599334b1e1478db01b756e55937d4614f8654311eb26012091be109d59"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3458a24e4ea3fd8930e934c129b676c27452e4ebda80fbe47b56d8c6c7a63a9e"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:838162460b3a08987546e881a2bfa573960bb559dfa739e7800ceeec92e64417"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f4e2d08f07a3d7d3e12549052eb5ad3eab1c349c53ac51c209a0e5991bbada78"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = 
"sha256:de119f56f3c5f0e2fb4dee508531a32b069a5f2c6e827b272d1e0ff5ac040333"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:149ddea5abf329752ea5051b61bd6c1d979e13fbf122d3a1f9f0c8be6cb6f63c"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:674ca19cbee4a82c9f54e0d1eee28116e63bc6fd1e96c43031d11cbab8b2afd5"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:9b3152f2f5677b997ae6c804b73da05a39daa6a9e85a512e0e6823d81cdad7cc"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5415d5a4b080dc9612b1b63cba008db84e908b95848369aa1da3686ae27b6d2b"}, - {file = "yarl-1.9.2-cp38-cp38-win32.whl", hash = "sha256:f7a3d8146575e08c29ed1cd287068e6d02f1c7bdff8970db96683b9591b86ee7"}, - {file = "yarl-1.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:63c48f6cef34e6319a74c727376e95626f84ea091f92c0250a98e53e62c77c72"}, - {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:75df5ef94c3fdc393c6b19d80e6ef1ecc9ae2f4263c09cacb178d871c02a5ba9"}, - {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c027a6e96ef77d401d8d5a5c8d6bc478e8042f1e448272e8d9752cb0aff8b5c8"}, - {file = "yarl-1.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3b078dbe227f79be488ffcfc7a9edb3409d018e0952cf13f15fd6512847f3f7"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59723a029760079b7d991a401386390c4be5bfec1e7dd83e25a6a0881859e716"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b03917871bf859a81ccb180c9a2e6c1e04d2f6a51d953e6a5cdd70c93d4e5a2a"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c1012fa63eb6c032f3ce5d2171c267992ae0c00b9e164efe4d73db818465fac3"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74dcbfe780e62f4b5a062714576f16c2f3493a0394e555ab141bf0d746bb955"}, - {file = 
"yarl-1.9.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c56986609b057b4839968ba901944af91b8e92f1725d1a2d77cbac6972b9ed1"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2c315df3293cd521033533d242d15eab26583360b58f7ee5d9565f15fee1bef4"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b7232f8dfbd225d57340e441d8caf8652a6acd06b389ea2d3222b8bc89cbfca6"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:53338749febd28935d55b41bf0bcc79d634881195a39f6b2f767870b72514caf"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:066c163aec9d3d073dc9ffe5dd3ad05069bcb03fcaab8d221290ba99f9f69ee3"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8288d7cd28f8119b07dd49b7230d6b4562f9b61ee9a4ab02221060d21136be80"}, - {file = "yarl-1.9.2-cp39-cp39-win32.whl", hash = "sha256:b124e2a6d223b65ba8768d5706d103280914d61f5cae3afbc50fc3dfcc016623"}, - {file = "yarl-1.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:61016e7d582bc46a5378ffdd02cd0314fb8ba52f40f9cf4d9a5e7dbef88dee18"}, - {file = "yarl-1.9.2.tar.gz", hash = "sha256:04ab9d4b9f587c06d801c2abfe9317b77cdf996c65a90d5e84ecc45010823571"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"}, + {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"}, + {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"}, + {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"}, + {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = 
"sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"}, + {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = 
"sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"}, + {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"}, + {file = "yarl-1.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63b20738b5aac74e239622d2fe30df4fca4942a86e31bf47a81a0e94c14df94f"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d7f7de27b8944f1fee2c26a88b4dabc2409d2fea7a9ed3df79b67277644e17"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c74018551e31269d56fab81a728f683667e7c28c04e807ba08f8c9e3bba32f14"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca06675212f94e7a610e85ca36948bb8fc023e458dd6c63ef71abfd482481aa5"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aef935237d60a51a62b86249839b51345f47564208c6ee615ed2a40878dccdd"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b134fd795e2322b7684155b7855cc99409d10b2e408056db2b93b51a52accc7"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d25039a474c4c72a5ad4b52495056f843a7ff07b632c1b92ea9043a3d9950f6e"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f7d6b36dd2e029b6bcb8a13cf19664c7b8e19ab3a58e0fefbb5b8461447ed5ec"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:957b4774373cf6f709359e5c8c4a0af9f6d7875db657adb0feaf8d6cb3c3964c"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d7eeb6d22331e2fd42fce928a81c697c9ee2d51400bd1a28803965883e13cead"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6a962e04b8f91f8c4e5917e518d17958e3bdee71fd1d8b88cdce74dd0ebbf434"}, + {file = "yarl-1.9.4-cp37-cp37m-win32.whl", hash = 
"sha256:f3bc6af6e2b8f92eced34ef6a96ffb248e863af20ef4fde9448cc8c9b858b749"}, + {file = "yarl-1.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4d7a90a92e528aadf4965d685c17dacff3df282db1121136c382dc0b6014d2"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ec61d826d80fc293ed46c9dd26995921e3a82146feacd952ef0757236fc137be"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8be9e837ea9113676e5754b43b940b50cce76d9ed7d2461df1af39a8ee674d9f"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bef596fdaa8f26e3d66af846bbe77057237cb6e8efff8cd7cc8dff9a62278bbf"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d47552b6e52c3319fede1b60b3de120fe83bde9b7bddad11a69fb0af7db32f1"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc30f71689d7fc9168b92788abc977dc8cefa806909565fc2951d02f6b7d57"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa9741085f635934f3a2583e16fcf62ba835719a8b2b28fb2917bb0537c1dfa"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:206a55215e6d05dbc6c98ce598a59e6fbd0c493e2de4ea6cc2f4934d5a18d130"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07574b007ee20e5c375a8fe4a0789fad26db905f9813be0f9fef5a68080de559"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a2e2433eb9344a163aced6a5f6c9222c0786e5a9e9cac2c89f0b28433f56e23"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6ad6d10ed9b67a382b45f29ea028f92d25bc0bc1daf6c5b801b90b5aa70fb9ec"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:6fe79f998a4052d79e1c30eeb7d6c1c1056ad33300f682465e1b4e9b5a188b78"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = 
"sha256:a825ec844298c791fd28ed14ed1bffc56a98d15b8c58a20e0e08c1f5f2bea1be"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8619d6915b3b0b34420cf9b2bb6d81ef59d984cb0fde7544e9ece32b4b3043c3"}, + {file = "yarl-1.9.4-cp38-cp38-win32.whl", hash = "sha256:686a0c2f85f83463272ddffd4deb5e591c98aac1897d65e92319f729c320eece"}, + {file = "yarl-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a00862fb23195b6b8322f7d781b0dc1d82cb3bcac346d1e38689370cc1cc398b"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = 
"sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"}, + {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"}, + {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"}, + {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"}, + {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"}, ] [package.dependencies] @@ -5091,4 +5114,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "7741ae68f922bc6c30656483e5337f5544c5f57c51a4102e067965014c806604" +content-hash = "cf88ba97e4847d4220e2fb639a587d62aa5a98e36fbfc632d7e3914cd08dcebb" diff --git a/pyproject.toml b/pyproject.toml index f89a4bcf3..fe494785d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,112 +1,96 @@ [tool.poetry] -name = "transformer-lens" -version = "0.0.0" # This is automatically set by the CD pipeline on release -description = "An implementation of transformers tailored for mechanistic interpretability." 
-authors = ["Neel Nanda <77788841+neelnanda-io@users.noreply.github.com>"] -license = "MIT" -readme = "README.md" -packages = [{include = "transformer_lens"}] + authors=["Neel Nanda <77788841+neelnanda-io@users.noreply.github.com>"] + description="An implementation of transformers tailored for mechanistic interpretability." + license="MIT" + name="transformer-lens" + packages=[{include="transformer_lens"}] + readme="README.md" + # Version is automatically set by the pipeline on release + version="0.0.0" -[tool.poetry.scripts] -build-docs = "docs.make_docs:build_docs" -docs-hot-reload = "docs.make_docs:docs_hot_reload" + [tool.poetry.scripts] + build-docs="docs.make_docs:build_docs" + docs-hot-reload="docs.make_docs:docs_hot_reload" -[tool.poetry.dependencies] -python = ">=3.8,<4.0" -einops = ">=0.6.0" -numpy = [{ version = ">=1.20,<1.25", python = ">=3.8,<3.9" }, - { version = ">=1.24", python = ">=3.9,<3.12" }, - { version = ">=1.26", python = ">=3.12,<3.13" }] -torch = ">=1.10,!=2.0,!=2.1.0" -# See PyTorch 2 fix below. 
We pin >=2.1.1 due to MPS errors (See our Slack) + [tool.poetry.dependencies] + accelerate=">=0.23.0" # Needed for Llama Models + beartype="^0.14.1" + datasets=">=2.7.1" + einops=">=0.6.0" + fancy-einsum=">=0.0.3" + jaxtyping=">=0.2.11" + numpy=[ + {version=">=1.20,<1.25", python=">=3.8,<3.9"}, + {version=">=1.24", python=">=3.9,<3.12"}, + {version=">=1.26", python=">=3.12,<3.13"}, + ] + pandas=">=1.1.5" + python=">=3.8,<4.0" + rich=">=12.6.0" + torch=">=1.10,!=2.0,!=2.1.0" # Pin >=2.1.1 due to known MPS errors on 2.1.0 + tqdm=">=4.64.1" + transformers=">=4.25.1" + typing-extensions="*" + wandb=">=0.13.5" -datasets = ">=2.7.1" -transformers = ">=4.25.1" -tqdm = ">=4.64.1" -pandas = ">=1.1.5" -wandb = ">=0.13.5" -fancy-einsum = ">=0.0.3" -rich = ">=12.6.0" -jaxtyping = ">=0.2.11" -beartype = "^0.14.1" -accelerate = ">=0.23.0" # Needed for Llama Models -typing-extensions = "*" -# PyTorch 2.0 Bug Fix PyTorch didn't put their dependencies metadata into all wheels for 2.1.0, so -# it doesn't work with Poetry. This is a known bug - the workaround is to place them manually here -# (from the one wheel that did correctly list them). This was broken in 2.0.1 and the fix wasn't -# made for 2.1.0, however Meta are aware of the issue and once it is fixed (and the torch version -# requirement bumped) this should be removed. Note also the python version is used to specify that -# this is only added where v2 torch is installed (as per the torch version requirement above). 
-# https://github.com/pytorch/pytorch/issues/100974 -# https://github.com/python-poetry/poetry/issues/7902#issuecomment-1583078794 -nvidia-cuda-nvrtc-cu12 = { version = ">=12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" } -nvidia-cuda-runtime-cu12 = { version = ">=12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" } -nvidia-cuda-cupti-cu12 = { version = ">=12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" } -nvidia-cudnn-cu12 = { version = ">=8.9.2.26", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" } -nvidia-cublas-cu12 = { version = ">=12.1.3.1", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" } -nvidia-cufft-cu12 = { version = ">=11.0.2.54", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" } -nvidia-curand-cu12 = { version = ">=10.3.2.106", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" } -nvidia-cusolver-cu12 = { version = ">=11.4.5.107", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" } -nvidia-cusparse-cu12 = { version = ">=12.1.0.106", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" } -nvidia-nccl-cu12 = { version = ">=2.18.1", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" } -nvidia-nvtx-cu12 = { version = ">=12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" } -triton = { version = ">=2.1.0", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" } -# End PyTorch 2.1.0 Bug Fix + [tool.poetry.group] + [tool.poetry.group.dev.dependencies] + black="^23.3.0" + circuitsvis=">=1.38.1" + isort="5.8.0" + jupyter=">=1.0.0" + mypy=">=0.991" + nbval="^0.10.0" + plotly=">=5.12.0" + pycln="^2.1.3" + pytest=">=7.2.0" + pytest-cov=">=4.0.0" + pytest-doctestplus="^1.0.0" -[tool.poetry.group.dev.dependencies] -pytest = ">=7.2.0" -pytest-cov = ">=4.0.0" -mypy = 
">=0.991" -jupyter = ">=1.0.0" -circuitsvis = ">=1.38.1" -plotly = ">=5.12.0" -isort = "5.8.0" -black = "^23.3.0" -pycln = "^2.1.3" -pytest-doctestplus = "^1.0.0" -nbval = "^0.10.0" + [tool.poetry.group.jupyter.dependencies] + ipywidgets="^8.1.1" + jupyterlab=">=3.5.0" -[tool.poetry.group.jupyter.dependencies] -jupyterlab = ">=3.5.0" -ipywidgets = "^8.1.1" + [tool.poetry.group.docs.dependencies] + furo={version=">=2022.12.7"} + myst-parser={version=">=0.18.1"} + nbconvert="^7.9.2" + nbsphinx="^0.9.3" + pandoc="^2.3" + snowballstemmer="*" + sphinx={version="5.2.3"} + sphinx-autobuild={version=">=2021.3.14"} + sphinxcontrib-napoleon={version=">=0.7"} + tabulate={version=">=0.9.0"} -[tool.poetry.group.docs.dependencies] -sphinx = {version = "5.2.3" } -sphinxcontrib-napoleon = {version = ">=0.7" } -sphinx-autobuild = {version = ">=2021.3.14" } -furo = {version = ">=2022.12.7" } -myst-parser = {version = ">=0.18.1" } -tabulate= {version = ">=0.9.0" } -snowballstemmer = "*" -nbsphinx = "^0.9.3" -pandoc = "^2.3" -nbconvert = "^7.9.2" - -[tool.pytest.ini_options] -doctest_optionflags = "NORMALIZE_WHITESPACE ELLIPSIS FLOAT_CMP" -filterwarnings = [ - "ignore:pkg_resources is deprecated as an API:DeprecationWarning", - # Ignore numpy.distutils deprecation warning caused by pandas - # More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils - "ignore:distutils Version classes are deprecated:DeprecationWarning" -] -addopts = """--jaxtyping-packages=transformer_lens,beartype.beartype \ --W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning \ ---deselect transformer_lens/utils.py::test_prompt \ ---doctest-modules --doctest-plus \ ---nbval""" +[tool.pytest] + [tool.pytest.ini_options] + addopts=[ + "--doctest-modules", + "--doctest-plus", + "--jaxtyping-packages=transformer_lens,beartype.beartype", + "--nbval", + "-W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning", + ] + doctest_optionflags="NORMALIZE_WHITESPACE ELLIPSIS 
FLOAT_CMP" + filterwarnings=[ + "ignore:pkg_resources is deprecated as an API:DeprecationWarning", + # Ignore numpy.distutils deprecation warning caused by pandas + # More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils + "ignore:distutils Version classes are deprecated:DeprecationWarning", + ] [tool.isort] -profile = "black" -extend_skip = ["__init__.py", ".venv/"] + extend_skip=[".venv/", "__init__.py"] + profile="black" [tool.mypy] -ignore_missing_imports = true -check_untyped_defs = true + check_untyped_defs=true + ignore_missing_imports=true [tool.black] -# Exclude snapshot tests & .venv -exclude = ''' + # Exclude snapshot tests & .venv + exclude=''' ( /snapshots/ | .venv/ @@ -115,21 +99,21 @@ exclude = ''' [tool.pylint] [tool.pylint.TYPECHECK] - # Fix for Pytorch member existence checks - generated-members = "torch.*" + # Fix for Pytorch member existence checks + generated-members="torch.*" [tool.pylint.DESIGN] - max-args = 10 - max-locals = 30 + max-args=10 + max-locals=30 [tool.pylint."MESSAGES CONTROL"] - disable = "redefined-builtin" # Disable redefined builtin functions + disable="redefined-builtin" # Disable redefined builtin functions [tool.pylint.'MASTER'] - disable = [ - "C0103", # Disable invalid file name (as we use PascalCase for classes) - ] + disable=[ + "C0103", # Disable invalid file name (as we use PascalCase for classes) + ] [build-system] -requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" + build-backend="poetry.core.masonry.api" + requires=["poetry-core"] diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index 5306e502f..fcf4fcf3a 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -13,6 +13,7 @@ import einops import numpy as np +import pytest import torch import torch.nn.functional as F import transformers @@ -601,6 +602,7 @@ def remove_batch_dim( # Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit # 
test (because it's name is prefixed `test_`). +@pytest.mark.skip def test_prompt( prompt: str, answer: str, From af99428ce64764b9c2d40098081604b27bb48595 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Mon, 11 Dec 2023 19:20:03 +0100 Subject: [PATCH 22/73] added import again --- transformer_lens/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py index affa9b69d..e2fb1484b 100644 --- a/transformer_lens/__init__.py +++ b/transformer_lens/__init__.py @@ -5,6 +5,7 @@ HookedTransformerKeyValueCache, HookedTransformerKeyValueCacheEntry, ) +from . import components from .HookedTransformerConfig import HookedTransformerConfig from .FactoredMatrix import FactoredMatrix from .ActivationCache import ActivationCache From 11f2088d8d8c93b58309362dff4a85732f62748f Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Mon, 11 Dec 2023 19:25:48 +0100 Subject: [PATCH 23/73] updated attention component --- transformer_lens/components/attention.py | 70 +++++++++--------------- 1 file changed, 25 insertions(+), 45 deletions(-) diff --git a/transformer_lens/components/attention.py b/transformer_lens/components/attention.py index 115f70f17..b33c985fb 100644 --- a/transformer_lens/components/attention.py +++ b/transformer_lens/components/attention.py @@ -37,12 +37,9 @@ def __init__( layer_id (int, optional): The index of the current layer. Used by the Mistal models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. 
""" super().__init__() - - self.cached_alibi = None - self.cfg = ( - HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg - ) - + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg self.W_Q = nn.Parameter( torch.empty( self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype @@ -74,29 +71,31 @@ def __init__( ) self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype)) + self.attn_type = attn_type # Create a max_ctx x max_ctx mask, with True iff that query position # can attend to that key position (query is first axis, key is second axis) causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool()) - - if attn_type == "global": + if self.attn_type == "global": # For global attention, this is a lower triangular matrix - key <= query self.register_buffer("mask", causal_mask) - elif attn_type == "local": + elif self.attn_type == "local": # For local, this is banded, query - window_size < key <= query assert isinstance(self.cfg.window_size, int) self.register_buffer( "mask", torch.triu(causal_mask, 1 - self.cfg.window_size) ) else: - raise ValueError(f"Invalid attention type: {attn_type}") + raise ValueError(f"Invalid attention type: {self.attn_type}") self.register_buffer("IGNORE", torch.tensor(-torch.inf)) self.layer_id = layer_id # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability? 
- self.attn_scale = np.sqrt(self.cfg.d_head) if self.cfg.use_attn_scale else 1.0 - + if self.cfg.use_attn_scale: + self.attn_scale = np.sqrt(self.cfg.d_head) + else: + self.attn_scale = 1.0 if self.cfg.scale_attn_by_inverse_layer_idx: self.attn_scale *= self.layer_id + 1 @@ -121,6 +120,10 @@ def __init__( ) self.register_buffer("rotary_sin", sin) self.register_buffer("rotary_cos", cos) + elif self.cfg.positional_embedding_type == "alibi": + # ALiBi bias wil be constructed on the first forward pass. + # Note: While computationally efficient, initializing an bias with max n_ctx (16, 1024, 1024) of float32 will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage. + self.alibi = None @property def OV(self) -> FactoredMatrix: @@ -176,7 +179,6 @@ def forward( qkv_einops_string = "batch pos head_index d_model" else: qkv_einops_string = "batch pos d_model" - q = self.hook_q( einsum( f"{qkv_einops_string}, head_index d_model d_head \ @@ -242,9 +244,13 @@ def forward( # The key context length is the number of positions in the past - this includes all positions in the cache key_ctx = attn_scores.size(-1) - alibi = self.get_cached_alibi(key_ctx=key_ctx) + # only recompute when necessary to increase efficiency. 
+ if self.alibi is None or key_ctx > self.alibi.size(-1): + self.alibi = Attention.create_alibi_bias( + self.cfg.n_heads, key_ctx, self.cfg.device + ) - attn_scores += alibi[ + attn_scores += self.alibi[ :, :query_ctx, :key_ctx ] # [batch, head_index, query_pos, key_pos] @@ -253,7 +259,6 @@ def forward( attn_scores = self.apply_causal_mask( attn_scores, kv_cache_pos_offset, attention_mask ) # [batch, head_index, query_pos, key_pos] - if additive_attention_mask is not None: attn_scores += additive_attention_mask @@ -271,9 +276,8 @@ def forward( pattern, ) ) # [batch, pos, head_index, d_head] - if not self.cfg.use_attn_result: - return ( + out = ( ( einsum( "batch pos head_index d_head, \ @@ -297,12 +301,13 @@ def forward( self.W_O, ) ) # [batch, pos, head_index, d_model] - return ( + out = ( einops.reduce( result, "batch position index model->batch position model", "sum" ) + self.b_O ) # [batch, pos, d_model] + return out def apply_causal_mask( self, @@ -326,7 +331,6 @@ def apply_causal_mask( final_mask = self.mask[ None, None, -query_ctx_length:, -key_ctx_length: ] # [1, 1, pos, pos] - if attention_mask is not None: # Apply a causal mask to the attention scores considering the padding einsum_str = "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos" @@ -359,10 +363,8 @@ def calculate_sin_cos_rotary( freq = einops.repeat(freq, "d -> (2 d)") else: freq = einops.repeat(freq, "d -> (d 2)") - # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency angles = pos[:, None] / freq[None, :] - return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) def rotate_every_two( @@ -544,26 +546,4 @@ def create_alibi_bias( # The ALiBi bias is then m * slope_matrix alibi_bias = torch.einsum("ij,k->kij", slope, multipliers) - return alibi_bias - - def get_cached_alibi( - self, key_ctx: int - ) -> Float[torch.Tensor, "head_idx query key"]: - """Get A Cached ALiBi bias For Calculation. 
- - This function will check for if an instance of our ALiBi bias is currently set. - If the ALiBi bias is not set or if our key context is greater than it's cached size, a new - instance will be initiated. - - The cached ALiBi bias is then returned - - Returns: - The ALiBi bias that should be added to the attention scores before the softmax. - """ - # only recompute when necessary to increase efficiency. - if self.cached_alibi is None or key_ctx > self.cached_alibi.size(-1): - self.cached_alibi = Attention.create_alibi_bias( - self.cfg.n_heads, key_ctx, self.cfg.device - ) - - return self.cached_alibi + return alibi_bias \ No newline at end of file From 98486bd8a657b40b0fb17f795a1c61bdb2ecff5d Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Mon, 11 Dec 2023 19:33:36 +0100 Subject: [PATCH 24/73] added new line --- transformer_lens/components/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/components/attention.py b/transformer_lens/components/attention.py index b33c985fb..0a43169ef 100644 --- a/transformer_lens/components/attention.py +++ b/transformer_lens/components/attention.py @@ -546,4 +546,4 @@ def create_alibi_bias( # The ALiBi bias is then m * slope_matrix alibi_bias = torch.einsum("ij,k->kij", slope, multipliers) - return alibi_bias \ No newline at end of file + return alibi_bias From 724104234c4783c63ede140159baa12d90e432bb Mon Sep 17 00:00:00 2001 From: Aaquib Syed <47124521+Aaquib111@users.noreply.github.com> Date: Tue, 16 Jan 2024 15:31:04 -0800 Subject: [PATCH 25/73] Closes #478: Adding the Qwen family of models (#477) * Fixing numerical issues * Added qwen lol * setup local * allclose * Added qwen * Cleaned up implementation * removed untested models * Cleaned up implementation removed untested models * commented untested models * formatting * fixed mem issues + trust_remote_code * formatting * merge * Force rerun checks --------- Co-authored-by: Andy Arditi --- demos/Qwen.ipynb | 372 
++++++++++++++++++++ transformer_lens/HookedTransformer.py | 5 +- transformer_lens/HookedTransformerConfig.py | 2 + transformer_lens/components.py | 15 +- transformer_lens/loading_from_pretrained.py | 120 ++++++- transformer_lens/utils.py | 4 +- 6 files changed, 503 insertions(+), 15 deletions(-) create mode 100644 demos/Qwen.ipynb diff --git a/demos/Qwen.ipynb b/demos/Qwen.ipynb new file mode 100644 index 000000000..d49b39578 --- /dev/null +++ b/demos/Qwen.ipynb @@ -0,0 +1,372 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting tiktoken\n", + " Downloading tiktoken-0.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)\n", + "Collecting regex>=2022.1.18 (from tiktoken)\n", + " Downloading regex-2023.12.25-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.9/40.9 kB\u001b[0m \u001b[31m718.2 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m \u001b[36m0:00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: requests>=2.26.0 in /opt/conda/lib/python3.10/site-packages (from tiktoken) (2.31.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (2.0.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (1.26.18)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (2023.11.17)\n", + "Downloading tiktoken-0.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)\n", + "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hDownloading regex-2023.12.25-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (773 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m774.0/774.0 kB\u001b[0m \u001b[31m11.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", + "\u001b[?25hInstalling collected packages: regex, tiktoken\n", + "Successfully installed regex-2023.12.25 tiktoken-0.5.2\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n", + "Collecting transformers_stream_generator\n", + " Downloading transformers-stream-generator-0.0.4.tar.gz (12 kB)\n", + " Preparing metadata (setup.py) ... 
\u001b[?25ldone\n", + "\u001b[?25hCollecting transformers>=4.26.1 (from transformers_stream_generator)\n", + " Downloading transformers-4.36.2-py3-none-any.whl.metadata (126 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m126.8/126.8 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (3.13.1)\n", + "Collecting huggingface-hub<1.0,>=0.19.3 (from transformers>=4.26.1->transformers_stream_generator)\n", + " Downloading huggingface_hub-0.20.2-py3-none-any.whl.metadata (12 kB)\n", + "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (1.26.2)\n", + "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (23.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (6.0.1)\n", + "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (2023.12.25)\n", + "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (2.31.0)\n", + "Collecting tokenizers<0.19,>=0.14 (from transformers>=4.26.1->transformers_stream_generator)\n", + " Downloading tokenizers-0.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n", + "Collecting safetensors>=0.3.1 (from transformers>=4.26.1->transformers_stream_generator)\n", + " Downloading safetensors-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)\n", + "Requirement already 
satisfied: tqdm>=4.27 in /opt/conda/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (4.65.0)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.19.3->transformers>=4.26.1->transformers_stream_generator) (2023.12.2)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.19.3->transformers>=4.26.1->transformers_stream_generator) (4.7.1)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->transformers>=4.26.1->transformers_stream_generator) (2.0.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->transformers>=4.26.1->transformers_stream_generator) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->transformers>=4.26.1->transformers_stream_generator) (1.26.18)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->transformers>=4.26.1->transformers_stream_generator) (2023.11.17)\n", + "Downloading transformers-4.36.2-py3-none-any.whl (8.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.2/8.2 MB\u001b[0m \u001b[31m64.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hDownloading huggingface_hub-0.20.2-py3-none-any.whl (330 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m330.3/330.3 kB\u001b[0m \u001b[31m28.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading safetensors-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m63.3 
MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading tokenizers-0.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.8/3.8 MB\u001b[0m \u001b[31m77.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m\n", + "\u001b[?25hBuilding wheels for collected packages: transformers_stream_generator\n", + " Building wheel for transformers_stream_generator (setup.py) ... \u001b[?25ldone\n", + "\u001b[?25h Created wheel for transformers_stream_generator: filename=transformers_stream_generator-0.0.4-py3-none-any.whl size=12315 sha256=44d1037124d6e69b847e846035b01ac56e5ebf6d4b115a332c16e85d50c4dc42\n", + " Stored in directory: /root/.cache/pip/wheels/47/1d/3c/92d88493ed40c0d9be60a391eb76c9a56e9f9b7542cb789401\n", + "Successfully built transformers_stream_generator\n", + "Installing collected packages: safetensors, huggingface-hub, tokenizers, transformers, transformers_stream_generator\n", + "Successfully installed huggingface-hub-0.20.2 safetensors-0.4.1 tokenizers-0.15.0 transformers-4.36.2 transformers_stream_generator-0.0.4\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install transformers_stream_generator plotly circuitsvis huggingface_hub einops tiktoken datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Jupyter notebook - intended for development only!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_11422/410710250.py:21: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + " ipython.magic(\"load_ext autoreload\")\n", + "/tmp/ipykernel_11422/410710250.py:22: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + " ipython.magic(\"autoreload 2\")\n" + ] + } + ], + "source": [ + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "DEVELOPMENT_MODE = False\n", + "try:\n", + " import google.colab\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n", + " %pip install circuitsvis\n", + " \n", + " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", + " # # Install another version of node that makes PySvelte work way faster\n", + " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", + " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", + "except:\n", + " IN_COLAB = False\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + " from IPython import get_ipython\n", + "\n", + " ipython = 
get_ipython()\n", + " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", + " ipython.magic(\"load_ext autoreload\")\n", + " ipython.magic(\"autoreload 2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using renderer: colab\n" + ] + } + ], + "source": [ + "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", + "import plotly.io as pio\n", + "if IN_COLAB or not DEVELOPMENT_MODE:\n", + " pio.renderers.default = \"colab\"\n", + "else:\n", + " pio.renderers.default = \"notebook_connected\"\n", + "print(f\"Using renderer: {pio.renderers.default}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/root/TransformerLens\n" + ] + } + ], + "source": [ + "%cd ~/TransformerLens\n", + "import torch\n", + "torch.set_grad_enabled(False)\n", + "\n", + "from transformers import AutoTokenizer\n", + "from transformer_lens import HookedTransformer\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "from transformers.generation import GenerationConfig\n", + "\n", + "from functools import partial" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Try importing flash-attention for faster inference...\n", + "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary\n", + "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm\n", + "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency 
https://github.com/Dao-AILab/flash-attention\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "943f344bd8c141738f6f3bd9db5c8514", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/8 [00:00 Float[torch.Tensor, "batch pos length"]: if self.cfg.dtype not in [torch.float32, torch.float64]: x = x.to(torch.float32) - scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale( (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt() ) @@ -722,10 +721,10 @@ def calculate_sin_cos_rotary( # A set of frequencies evenly spaced in log space freq = base ** (dim / (rotary_dim / 2)) - if self.cfg.original_architecture in ["GPTNeoXForCausalLM", "LlamaForCausalLM"]: - freq = einops.repeat(freq, "d -> (2 d)") - else: + if self.cfg.rotary_adjacent_pairs: freq = einops.repeat(freq, "d -> (d 2)") + else: + freq = einops.repeat(freq, "d -> (2 d)") # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency angles = pos[:, None] / freq[None, :] return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) @@ -741,13 +740,13 @@ def rotate_every_two( GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details. 
""" rot_x = x.clone() - if self.cfg.original_architecture in ["GPTNeoXForCausalLM", "LlamaForCausalLM"]: + if self.cfg.rotary_adjacent_pairs: + rot_x[..., ::2] = -x[..., 1::2] + rot_x[..., 1::2] = x[..., ::2] + else: n = x.size(-1) // 2 rot_x[..., :n] = -x[..., n:] rot_x[..., n:] = x[..., :n] - else: - rot_x[..., ::2] = -x[..., 1::2] - rot_x[..., 1::2] = x[..., ::2] return rot_x diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 123582f1a..56c23955a 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -138,6 +138,12 @@ "stabilityai/stablelm-tuned-alpha-7b", "bigscience/bloom-560m", "bigcode/santacoder", + "Qwen/Qwen-1_8B", + "Qwen/Qwen-7B", + "Qwen/Qwen-14B", + "Qwen/Qwen-1_8B-Chat", + "Qwen/Qwen-7B-Chat", + "Qwen/Qwen-14B-Chat", ] """Official model names for models on HuggingFace.""" @@ -498,6 +504,12 @@ ], "bigscience/bloom-560m": ["bloom-560m"], "bigcode/santacoder": ["santacoder"], + "Qwen/Qwen-1_8B": ["qwen-1.8b"], + "Qwen/Qwen-7B": ["qwen-7b"], + "Qwen/Qwen-14B": ["qwen-14b"], + "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"], + "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"], + "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"], } """Model aliases for models on HuggingFace.""" @@ -507,6 +519,8 @@ for name in OFFICIAL_MODEL_NAMES ] +NEED_REMOTE_CODE_MODELS = ("bigcode/santacoder", "Qwen/Qwen-") + def make_model_alias_map(): """ @@ -566,6 +580,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "act_fn": "silu", "normalization_type": "RMS", "positional_embedding_type": "rotary", + "rotary_adjacent_pairs": False, "rotary_dim": 4096 // 32, "final_rms": True, "gated_mlp": True, @@ -585,6 +600,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "act_fn": "silu", "normalization_type": "RMS", "positional_embedding_type": "rotary", + "rotary_adjacent_pairs": False, "rotary_dim": 5120 // 40, "final_rms": True, "gated_mlp": True, @@ -602,6 +618,7 @@ def 
convert_hf_model_config(model_name: str, **kwargs): "act_fn": "silu", "normalization_type": "RMS", "positional_embedding_type": "rotary", + "rotary_adjacent_pairs": False, "rotary_dim": 6656 // 52, "final_rms": True, "gated_mlp": True, @@ -620,6 +637,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "normalization_type": "RMS", "positional_embedding_type": "rotary", "rotary_dim": 8192 // 64, + "rotary_adjacent_pairs": False, "final_rms": True, "gated_mlp": True, } @@ -690,6 +708,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "parallel_attn_mlp": True, "positional_embedding_type": "rotary", "rotary_dim": hf_config.rotary_dim, + "rotary_adjacent_pairs": True, "normalization_type": "LN", } elif architecture == "GPTNeoXForCausalLM": @@ -708,6 +727,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "scale_attn_by_inverse_layer_idx": False, "parallel_attn_mlp": True, "positional_embedding_type": "rotary", + "rotary_adjacent_pairs": False, "normalization_type": "LN", } rotary_pct = hf_config.rotary_pct @@ -755,9 +775,33 @@ def convert_hf_model_config(model_name: str, **kwargs): "act_fn": hf_config.activation_function, "use_attn_scale": True, "use_local_attn": False, + "trust_remote_code": "santacoder" + in official_model_name, # Only santacoder needs trust_remote_code "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, "normalization_type": "LN", } + elif architecture == "QWenLMHeadModel": + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size // 2, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big + "eps": hf_config.layer_norm_epsilon, + "d_vocab": hf_config.vocab_size, + "act_fn": "silu", + "use_attn_scale": hf_config.scale_attn_weights, + "initializer_range": 
hf_config.initializer_range, + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "rotary_dim": hf_config.kv_channels, + "rotary_adjacent_pairs": False, + "tokenizer_prepends_bos": True, + "trust_remote_code": True, + "final_rms": True, + "gated_mlp": True, + } else: raise NotImplementedError(f"{architecture} is not currently supported.") # All of these models use LayerNorm @@ -861,6 +905,13 @@ def get_pretrained_model_config( ): cfg_dict = convert_neel_model_config(official_model_name, **kwargs) else: + if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( + "trust_remote_code", False + ): + logging.warning( + f"Loading model {official_model_name} requires setting trust_remote_code=True" + ) + kwargs["trust_remote_code"] = True cfg_dict = convert_hf_model_config(official_model_name, **kwargs) # Processing common to both model types # Remove any prefix, saying the organization who made a model. @@ -977,8 +1028,6 @@ def get_checkpoint_labels(model_name: str, **kwargs): # %% Loading state dicts - - def get_pretrained_state_dict( official_model_name: str, cfg: HookedTransformerConfig, @@ -1001,11 +1050,11 @@ def get_pretrained_state_dict( dtype = kwargs["torch_dtype"] del kwargs["torch_dtype"] official_model_name = get_official_model_name(official_model_name) - if official_model_name == "bigcode/santacoder" and not kwargs.get( + if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( "trust_remote_code", False ): logging.warning( - "Loading santacoder model requires setting trust_remote_code=True" + f"Loading model {official_model_name} state dict requires setting trust_remote_code=True" ) kwargs["trust_remote_code"] = True if ( @@ -1091,6 +1140,8 @@ def get_pretrained_state_dict( state_dict = convert_bloom_weights(hf_model, cfg) elif cfg.original_architecture == "GPT2LMHeadCustomModel": state_dict = convert_coder_weights(hf_model, cfg) + elif cfg.original_architecture == "QWenLMHeadModel": + 
state_dict = convert_qwen_weights(hf_model, cfg) else: raise ValueError( f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." @@ -1420,6 +1471,67 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): return state_dict +def convert_qwen_weights(qwen, cfg: HookedTransformerConfig): + state_dict = {} + model = qwen.transformer + state_dict["embed.W_E"] = model.wte.weight + + for l in range(cfg.n_layers): + state_dict[f"blocks.{l}.ln1.w"] = model.h[l].ln_1.weight + + W_Q, W_K, W_V = model.h[l].attn.c_attn.weight.split( + split_size=cfg.d_model, dim=0 + ) + W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) + W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads) + W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads) + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q + state_dict[f"blocks.{l}.attn.W_K"] = W_K + state_dict[f"blocks.{l}.attn.W_V"] = W_V + + b_Q, b_K, b_V = model.h[l].attn.c_attn.bias.split(split_size=cfg.d_model, dim=0) + b_Q = einops.rearrange( + b_Q, + "(n_head d_head) -> n_head d_head", + n_head=cfg.n_heads, + ) + b_K = einops.rearrange( + b_K, + "(n_head d_head) -> n_head d_head", + n_head=cfg.n_heads, + ) + b_V = einops.rearrange( + b_V, + "(n_head d_head) -> n_head d_head", + n_head=cfg.n_heads, + ) + state_dict[f"blocks.{l}.attn.b_Q"] = b_Q + state_dict[f"blocks.{l}.attn.b_K"] = b_K + state_dict[f"blocks.{l}.attn.b_V"] = b_V + + W_O = model.h[l].attn.c_proj.weight + W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) + state_dict[f"blocks.{l}.attn.W_O"] = W_O + + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.ln2.w"] = model.h[l].ln_2.weight + + state_dict[f"blocks.{l}.mlp.W_in"] = model.h[l].mlp.w1.weight.T + state_dict[f"blocks.{l}.mlp.W_gate"] = model.h[l].mlp.w2.weight.T + 
state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.mlp.W_out"] = model.h[l].mlp.c_proj.weight.T + state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + state_dict["ln_final.w"] = model.ln_f.weight + + state_dict["unembed.W_U"] = qwen.lm_head.weight.T + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) + + return state_dict + + def convert_opt_weights(opt, cfg: HookedTransformerConfig): state_dict = {} diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index fcf4fcf3a..c3fe2963e 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -1083,7 +1083,9 @@ def get_tokenizer_with_bos(tokenizer): tokenizer_with_bos = tokenizer else: tokenizer_with_bos = AutoTokenizer.from_pretrained( - pretrained_model_name_or_path, add_bos_token=True, **init_kwargs + pretrained_model_name_or_path, + add_bos_token=True, + **init_kwargs, ) return tokenizer_with_bos From 5754a0b6ed3cc13ca4c92560519d0d8b78fb9541 Mon Sep 17 00:00:00 2001 From: adamkarvonen <85900742+adamkarvonen@users.noreply.github.com> Date: Tue, 16 Jan 2024 17:33:13 -0600 Subject: [PATCH 26/73] Add a function to convert nanogpt weights (#475) * Add a function to convert nanogpt weights * Remove need for bias parameter --- docs/source/conf.py | 1 + transformer_lens/loading_from_pretrained.py | 101 ++++++++++++++++++++ 2 files changed, 102 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index af38914c0..9308ea2b2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -80,6 +80,7 @@ "convert_gptj_weights", "convert_llama_weights", "convert_mingpt_weights", + "convert_nanogpt_weights", "convert_neel_solu_old_weights", "convert_neo_weights", "convert_neox_weights", diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 56c23955a..8e94fd1da 100644 --- a/transformer_lens/loading_from_pretrained.py +++ 
b/transformer_lens/loading_from_pretrained.py @@ -1721,6 +1721,107 @@ def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig): return state_dict +def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig): + """For https://github.com/karpathy/nanoGPT + There are two complications with converting nanogpt models: + The first is that some state dicts have an unwanted prefix on keys that needs to be removed. + The second is that the models can be saved with or without bias. By default, there + is no bias. This function can handle both cases.""" + # Nanogpt models saved after torch.compile() have this unwanted prefix + # This is a simple way to remove it + unwanted_prefix = "_orig_mod." + for k, v in list(old_state_dict.items()): + if k.startswith(unwanted_prefix): + old_state_dict[k[len(unwanted_prefix) :]] = old_state_dict.pop(k) + + new_state_dict = {} + new_state_dict["pos_embed.W_pos"] = old_state_dict["transformer.wpe.weight"] + new_state_dict["embed.W_E"] = old_state_dict["transformer.wte.weight"] + + new_state_dict["ln_final.w"] = old_state_dict["transformer.ln_f.weight"] + new_state_dict["ln_final.b"] = torch.zeros_like( + old_state_dict["transformer.ln_f.weight"] + ) + new_state_dict["unembed.W_U"] = old_state_dict["lm_head.weight"].T + + bias = False + if "transformer.ln_f.bias" in old_state_dict: + bias = True + new_state_dict["ln_final.b"] = old_state_dict["transformer.ln_f.bias"] + + for layer in range(cfg.n_layers): + layer_key = f"transformer.h.{layer}" + + new_state_dict[f"blocks.{layer}.ln1.w"] = old_state_dict[ + f"{layer_key}.ln_1.weight" + ] + # A bias of zeros is required for folding layer norm + new_state_dict[f"blocks.{layer}.ln1.b"] = torch.zeros_like( + old_state_dict[f"{layer_key}.ln_1.weight"] + ) + new_state_dict[f"blocks.{layer}.ln2.w"] = old_state_dict[ + f"{layer_key}.ln_2.weight" + ] + new_state_dict[f"blocks.{layer}.ln2.b"] = torch.zeros_like( + old_state_dict[f"{layer_key}.ln_2.weight"] + ) + + W = 
old_state_dict[f"{layer_key}.attn.c_attn.weight"] + W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0) + W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads) + W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads) + W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads) + new_state_dict[f"blocks.{layer}.attn.W_Q"] = W_Q + new_state_dict[f"blocks.{layer}.attn.W_K"] = W_K + new_state_dict[f"blocks.{layer}.attn.W_V"] = W_V + + W_O = old_state_dict[f"{layer_key}.attn.c_proj.weight"] + W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) + new_state_dict[f"blocks.{layer}.attn.W_O"] = W_O + + new_state_dict[f"blocks.{layer}.mlp.W_in"] = old_state_dict[ + f"{layer_key}.mlp.c_fc.weight" + ].T + new_state_dict[f"blocks.{layer}.mlp.W_out"] = old_state_dict[ + f"{layer_key}.mlp.c_proj.weight" + ].T + + if bias: + new_state_dict[f"blocks.{layer}.ln1.b"] = old_state_dict[ + f"{layer_key}.ln_1.bias" + ] + new_state_dict[f"blocks.{layer}.ln2.b"] = old_state_dict[ + f"{layer_key}.ln_2.bias" + ] + new_state_dict[f"blocks.{layer}.mlp.b_in"] = old_state_dict[ + f"{layer_key}.mlp.c_fc.bias" + ] + new_state_dict[f"blocks.{layer}.mlp.b_out"] = old_state_dict[ + f"{layer_key}.mlp.c_proj.bias" + ] + + B = old_state_dict[f"{layer_key}.attn.c_attn.bias"] + B_Q, B_K, B_V = torch.tensor_split(B, 3, dim=0) + B_Q = einops.rearrange(B_Q, "(i h)->i h", i=cfg.n_heads) + B_K = einops.rearrange(B_K, "(i h)->i h", i=cfg.n_heads) + B_V = einops.rearrange(B_V, "(i h)->i h", i=cfg.n_heads) + new_state_dict[f"blocks.{layer}.attn.b_Q"] = B_Q + new_state_dict[f"blocks.{layer}.attn.b_K"] = B_K + new_state_dict[f"blocks.{layer}.attn.b_V"] = B_V + new_state_dict[f"blocks.{layer}.attn.b_O"] = old_state_dict[ + f"{layer_key}.attn.c_proj.bias" + ] + + new_state_dict[f"blocks.{layer}.mlp.b_in"] = old_state_dict[ + f"{layer_key}.mlp.c_fc.bias" + ].T + new_state_dict[f"blocks.{layer}.mlp.b_out"] = old_state_dict[ + f"{layer_key}.mlp.c_proj.bias" + ].T + + return new_state_dict + + def 
convert_bert_weights(bert, cfg: HookedTransformerConfig): embeddings = bert.bert.embeddings state_dict = { From 535fadf020233c52a12ab1a980fcee80d4c4bd2c Mon Sep 17 00:00:00 2001 From: yuheng huang <32429436+YuhengHuang42@users.noreply.github.com> Date: Wed, 17 Jan 2024 08:51:09 -0700 Subject: [PATCH 27/73] Add support for CodeLlama-7b (#469) * Add Support for CodeLlama-7b * Reformat --------- Co-authored-by: Neel Nanda --- transformer_lens/HookedTransformerConfig.py | 1 + transformer_lens/components.py | 5 ++- transformer_lens/loading_from_pretrained.py | 35 +++++++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py index de12d115f..d54b785d7 100644 --- a/transformer_lens/HookedTransformerConfig.py +++ b/transformer_lens/HookedTransformerConfig.py @@ -197,6 +197,7 @@ class HookedTransformerConfig: dtype: torch.dtype = torch.float32 tokenizer_prepends_bos: Optional[bool] = None post_embedding_ln: bool = False + rotary_base: int = 10000 trust_remote_code: bool = False rotary_adjacent_pairs: bool = False diff --git a/transformer_lens/components.py b/transformer_lens/components.py index 114b7f63d..b8926cb43 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -478,7 +478,10 @@ def __init__( self.hook_rot_k = HookPoint() self.hook_rot_q = HookPoint() sin, cos = self.calculate_sin_cos_rotary( - self.cfg.rotary_dim, self.cfg.n_ctx, dtype=self.cfg.dtype + self.cfg.rotary_dim, + self.cfg.n_ctx, + base=self.cfg.rotary_base, + dtype=self.cfg.dtype, ) self.register_buffer("rotary_sin", sin) self.register_buffer("rotary_cos", cos) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 8e94fd1da..832d8c514 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -111,6 +111,9 @@ "llama-13b-hf", "llama-30b-hf", "llama-65b-hf", + 
"CodeLlama-7b-hf", + "CodeLlama-7b-Python-hf", + "CodeLlama-7b-Instruct-hf", "Llama-2-7b-hf", "Llama-2-7b-chat-hf", "Llama-2-13b-hf", @@ -466,6 +469,15 @@ "llama-13b-hf": ["llama-13b"], "llama-30b-hf": ["llama-30b"], "llama-65b-hf": ["llama-65b"], + "CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"], + "CodeLlama-7b-Python-hf": [ + "CodeLlama-7b-python", + "codellama/CodeLlama-7b-Python-hf", + ], + "CodeLlama-7b-Instruct-hf": [ + "CodeLlama-7b-instruct", + "codellama/CodeLlama-7b-Instruct-hf", + ], "Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"], "Llama-2-7b-chat-hf": ["Llama-2-7b-chat", "meta-llama/Llama-2-7b-chat-hf"], "Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"], @@ -585,6 +597,29 @@ def convert_hf_model_config(model_name: str, **kwargs): "final_rms": True, "gated_mlp": True, } + elif official_model_name.startswith( + "CodeLlama-7b" + ): # same architecture CodeLlama and Llama-2 + cfg_dict = { + "d_model": 4096, + "d_head": 4096 // 32, + "n_heads": 32, + "d_mlp": 11008, + "n_layers": 32, + "n_ctx": 4096, + "eps": 1e-5, + "d_vocab": 32016, + "act_fn": "silu", + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "rotary_dim": 4096 // 32, + "final_rms": True, + "gated_mlp": True, + "rotary_base": 1000000, + } + if "python" in official_model_name.lower(): + # The vocab size of python version of CodeLlama-7b is 32000 + cfg_dict["d_vocab"] = 32000 elif official_model_name.startswith( ("llama-13b", "Llama-2-13b") ): # same architecture for LLaMA and Llama-2 From 33222e5da2c4fcef558f7d207daedf2eccf80367 Mon Sep 17 00:00:00 2001 From: Andy Arditi Date: Wed, 17 Jan 2024 14:34:50 -0800 Subject: [PATCH 28/73] Make LLaMA 2 loadable directly from HF (#458) --------- Co-authored-by: Alan <41682961+alan-cooney@users.noreply.github.com> --- demos/LLaMA.ipynb | 272 +++++++++----------- transformer_lens/HookedTransformer.py | 8 +- transformer_lens/loading_from_pretrained.py | 40 ++- 3 files changed, 159 
insertions(+), 161 deletions(-) diff --git a/demos/LLaMA.ipynb b/demos/LLaMA.ipynb index 9df019d7c..9e9f428e6 100644 --- a/demos/LLaMA.ipynb +++ b/demos/LLaMA.ipynb @@ -14,73 +14,40 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# LLaMA and Llama-2 in TransformerLens\n", - "\n", - "This demo requires `transformers` version 4.31.0 (which adds Llama-2 support). This tutorial has part a) for LLaMA and b) for Llama-2. Currently the only Llama-2 support is the 7B chat model, as this notebook is being tested.\n", - "\n", - "Steps to run this demo:\n", - "\n", - "1a. Get LLaMA weights here: https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform\n", - "\n", - "1b. Get Llama-2 weights here: https://ai.meta.com/resources/models-and-libraries/llama-downloads/\n", - "\n", - "2a. Convert the official weights to huggingface. \n", - "\n", - "```bash\n", - "python src/transformers/models/llama/convert_llama_weights_to_hf.py \\\n", - " --input_dir /path/to/downloaded/llama/weights \\\n", - " --model_size 7B \\\n", - " --output_dir /llama/weights/directory/\n", - "```\n", - "\n", - "2b. Same step for Llama-2, we'll use `7Bf` the 7B chat version\n", - "\n", - "```bash\n", - "python src/transformers/models/llama/convert_llama_weights_to_hf.py \\\n", - " --input_dir /path/to/downloaded/llama-2/weights \\\n", - " --model_size 7Bf \\\n", - " --output_dir /llama/weights/directory/\n", - "```\n", - "\n", - "Note: this didn't work for Arthur by default (even though HF doesn't seem to show this anywhere). I had to change this line of my pip installed `src/transformers/models/llama/convert_llama_weights_to_hf.py` file (which was found at `/opt/conda/envs/arthurenv/lib/python3.10/site-packages/transformers/models/llama/convert_llama_weights_to_hf.py`) from \n", - "\n", - "`input_base_path=os.path.join(args.input_dir, args.model_size),` to `input_base_path=os.path.join(args.input_dir),`\n", - "\n", - "3. 
Change the ```MODEL_PATH``` variable in the cell below to where the converted weights are stored." + "# LLaMA and Llama-2 in TransformerLens" ] }, { - "cell_type": "code", - "execution_count": 1, + "attachments": {}, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "from typing import Literal\n", - "\n", - "MODE: Literal[\"LLaMA\", \"Llama-2\"] = \"Llama-2\" # change to LLaMA for original LLaMA\n", - "MODEL_PATH: str = \"\" # Set the path to the /llama/weights/directory/ that you used in the command" + "## Setup (skip)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], - "source": [ - "!pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n", + "Requirement already satisfied: sentencepiece in /root/TransformerLens/.venv/lib/python3.10/site-packages (0.1.99)\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ - "## Setup (skip)" + "%pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8\n", + "%pip install sentencepiece # Llama tokenizer requires sentencepiece" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -94,9 +61,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_20722/410710250.py:21: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + "/tmp/ipykernel_16979/572068249.py:21: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", " ipython.magic(\"load_ext 
autoreload\")\n", - "/tmp/ipykernel_20722/410710250.py:22: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + "/tmp/ipykernel_16979/572068249.py:22: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", " ipython.magic(\"autoreload 2\")\n" ] } @@ -128,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -153,45 +120,25 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Import stuff\n", "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import torch.optim as optim\n", - "import numpy as np\n", - "import einops\n", - "from fancy_einsum import einsum\n", "import tqdm.auto as tqdm\n", - "from tqdm import tqdm\n", - "import random\n", - "from pathlib import Path\n", "import plotly.express as px\n", - "from torch.utils.data import DataLoader\n", - "\n", - "from torchtyping import TensorType as TT\n", - "from typing import List, Union, Optional\n", - "from jaxtyping import Float, Int\n", - "from functools import partial\n", - "import copy\n", "\n", - "import itertools\n", - "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n", - "import dataclasses\n", - "import datasets\n", - "from IPython.display import HTML\n", - "# import circuitsvis as cv\n", + "from transformers import LlamaForCausalLM, LlamaTokenizer\n", + "from tqdm import tqdm\n", + "from jaxtyping import Float\n", "\n", "import transformer_lens\n", "import transformer_lens.utils as utils\n", "from transformer_lens.hook_points import (\n", - " HookedRootModule,\n", " HookPoint,\n", ") # Hooking utilities\n", - "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n", + "from transformer_lens 
import HookedTransformer\n", "\n", "torch.set_grad_enabled(False)\n", "\n", @@ -208,29 +155,95 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "## Loading model" + "## Loading LLaMA" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "LLaMA weights are not available on HuggingFace, so you'll need to download and convert them\n", + "manually:\n", + "\n", + "1. Get LLaMA weights here: https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform\n", + "\n", + "2. Convert the official weights to huggingface:\n", + "\n", + "```bash\n", + "python src/transformers/models/llama/convert_llama_weights_to_hf.py \\\n", + " --input_dir /path/to/downloaded/llama/weights \\\n", + " --model_size 7B \\\n", + " --output_dir /llama/weights/directory/\n", + "```\n", + "\n", + "Note: this didn't work for Arthur by default (even though HF doesn't seem to show this anywhere). I\n", + "had to change this\n", + "line of my pip installed `src/transformers/models/llama/convert_llama_weights_to_hf.py` file (which\n", + "was found at\n", + "`/opt/conda/envs/arthurenv/lib/python3.10/site-packages/transformers/models/llama/convert_llama_weights_to_hf.py`)\n", + "from `input_base_path=os.path.join(args.input_dir, args.model_size),` to `input_base_path=os.path.join(args.input_dir),`\n", + "\n", + "3. Change the ```MODEL_PATH``` variable in the cell below to where the converted weights are stored." 
] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_PATH=''\n", + "\n", + "tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)\n", + "hf_model = LlamaForCausalLM.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True)\n", + "\n", + "model = HookedTransformer.from_pretrained(\"llama-7b\", hf_model=hf_model, device=\"cpu\", fold_ln=False, center_writing_weights=False, center_unembed=False, tokenizer=tokenizer)\n", + "\n", + "model = model.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "model.generate(\"The capital of Germany is\", max_new_tokens=20, temperature=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading LLaMA-2\n", + "LLaMA-2 is hosted on HuggingFace, but gated by login.\n", + "\n", + "Before running the notebook, log in to HuggingFace via the cli on your machine:\n", + "```bash\n", + "transformers-cli login\n", + "```\n", + "This will cache your HuggingFace credentials, and enable you to download LLaMA-2." + ] + }, + { + "cell_type": "code", + "execution_count": 5, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "You are using the legacy behaviour of the . This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1821773a30ad4a56960ccae34e8e6a3d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00The capital of Germany is Berlin. Berlin is the largest city in Germany and is known for its rich history, cultural attractions'" + "'The capital of Germany is Berlin. 
Berlin is the largest city in Germany and is known for its rich history, cultural attractions'" ] }, - "execution_count": 7, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Loading on CPU is cheapest memory wise in transformer_lens \n", - "if MODE == \"LLaMA\":\n", - " model = HookedTransformer.from_pretrained(\"llama-7b\", hf_model=hf_model, device=\"cpu\", fold_ln=False, center_writing_weights=False, center_unembed=False, tokenizer=tokenizer)\n", + "LLAMA_2_7B_CHAT_PATH = \"meta-llama/Llama-2-7b-chat-hf\"\n", + "\n", + "tokenizer = LlamaTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH)\n", + "hf_model = LlamaForCausalLM.from_pretrained(LLAMA_2_7B_CHAT_PATH, low_cpu_mem_usage=True)\n", + "\n", + "model = HookedTransformer.from_pretrained(LLAMA_2_7B_CHAT_PATH, device=\"cpu\", fold_ln=False, center_writing_weights=False, center_unembed=False)\n", "\n", - "elif MODE == \"Llama-2\":\n", - " model = HookedTransformer.from_pretrained(\"meta-llama/Llama-2-7b-chat-hf\", hf_model=hf_model, device=\"cpu\", fold_ln=False, center_writing_weights=False, center_unembed=False, tokenizer=tokenizer)\n", - " \n", "model = model.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "model.generate(\"The capital of Germany is\", max_new_tokens=20, temperature=0)" ] @@ -321,22 +309,15 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/4 [00:00\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 11, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ @@ -419,7 +399,7 @@ "llama_str_tokens = model.to_str_tokens(llama_text)\n", "\n", "print(\"Layer 0 Head Attention Patterns:\")\n", - "cv.attention.attention_patterns(tokens=llama_str_tokens, attention=attention_pattern)" + "display(cv.attention.attention_patterns(tokens=llama_str_tokens, attention=attention_pattern))" ] }, { @@ -432,7 +412,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -440,8 +420,8 @@ "output_type": "stream", "text": [ "Shape of the value tensor: torch.Size([1, 34, 32, 128])\n", - "Original Loss: 2.933\n", - "Ablated Loss: 2.881\n" + "Original Loss: 2.931\n", + "Ablated Loss: 2.879\n" ] } ], @@ -490,7 +470,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.10.13" }, "orig_nbformat": 4, "vscode": { diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 2165de1e2..88c53ad7e 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -37,6 +37,7 @@ ) from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookedRootModule, HookPoint +from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES # Note - activation cache is used with run_with_cache, past_key_value_caching is used for # generation. @@ -118,9 +119,10 @@ def __init__( self.set_tokenizer(tokenizer, default_padding_side=default_padding_side) elif self.cfg.tokenizer_name is not None: # If we have a tokenizer name, we can load it from HuggingFace - if "llama" in self.cfg.tokenizer_name.lower(): - # llama tokenizer requires special handling - logging.warning("LLaMA tokenizer not loaded. 
Please load manually.") + if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES: + logging.warning( + f"{self.cfg.tokenizer_name} tokenizer not loaded. Please load manually." + ) else: self.set_tokenizer( AutoTokenizer.from_pretrained( diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 832d8c514..994a2b3a2 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -111,13 +111,13 @@ "llama-13b-hf", "llama-30b-hf", "llama-65b-hf", + "meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-2-7b-chat-hf", + "meta-llama/Llama-2-13b-hf", + "meta-llama/Llama-2-13b-chat-hf", "CodeLlama-7b-hf", "CodeLlama-7b-Python-hf", "CodeLlama-7b-Instruct-hf", - "Llama-2-7b-hf", - "Llama-2-7b-chat-hf", - "Llama-2-13b-hf", - "Llama-2-13b-chat-hf", # TODO Llama-2-70b-hf requires Grouped-Query Attention, see the paper https://arxiv.org/pdf/2307.09288.pdf "Baidicoot/Othello-GPT-Transformer-Lens", "bert-base-cased", @@ -469,6 +469,16 @@ "llama-13b-hf": ["llama-13b"], "llama-30b-hf": ["llama-30b"], "llama-65b-hf": ["llama-65b"], + "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"], + "meta-llama/Llama-2-7b-chat-hf": [ + "Llama-2-7b-chat", + "meta-llama/Llama-2-7b-chat-hf", + ], + "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"], + "meta-llama/Llama-2-13b-chat-hf": [ + "Llama-2-13b-chat", + "meta-llama/Llama-2-13b-chat-hf", + ], "CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"], "CodeLlama-7b-Python-hf": [ "CodeLlama-7b-python", @@ -478,10 +488,6 @@ "CodeLlama-7b-instruct", "codellama/CodeLlama-7b-Instruct-hf", ], - "Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"], - "Llama-2-7b-chat-hf": ["Llama-2-7b-chat", "meta-llama/Llama-2-7b-chat-hf"], - "Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"], - "Llama-2-13b-chat-hf": ["Llama-2-13b-chat", "meta-llama/Llama-2-13b-chat-hf"], # TODO Llama-2-70b-hf requires 
Grouped-Query Attention, see the paper https://arxiv.org/pdf/2307.09288.pdf "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"], "roneneldan/TinyStories-1M": ["tiny-stories-1M"], @@ -525,6 +531,14 @@ } """Model aliases for models on HuggingFace.""" +NON_HF_HOSTED_MODEL_NAMES = [ + "llama-7b-hf", + "llama-13b-hf", + "llama-30b-hf", + "llama-65b-hf", +] +"""Official model names for models that not hosted on HuggingFace.""" + # Sets a default model alias, by convention the first one in the model alias table, else the official name if it has no aliases DEFAULT_MODEL_ALIASES = [ MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name @@ -578,7 +592,7 @@ def convert_hf_model_config(model_name: str, **kwargs): else: architecture = "LlamaForCausalLM" if official_model_name.startswith( - ("llama-7b", "Llama-2-7b") + ("llama-7b", "meta-llama/Llama-2-7b") ): # same architecture for LLaMA and Llama-2 cfg_dict = { "d_model": 4096, @@ -621,7 +635,7 @@ def convert_hf_model_config(model_name: str, **kwargs): # The vocab size of python version of CodeLlama-7b is 32000 cfg_dict["d_vocab"] = 32000 elif official_model_name.startswith( - ("llama-13b", "Llama-2-13b") + ("llama-13b", "meta-llama/Llama-2-13b") ): # same architecture for LLaMA and Llama-2 cfg_dict = { "d_model": 5120, @@ -1141,8 +1155,10 @@ def get_pretrained_state_dict( f"Checkpoints for model {official_model_name} are not supported" ) elif hf_model is None: - if "llama" in official_model_name.lower(): - raise NotImplementedError("Must pass in hf_model for LLaMA models") + if official_model_name in NON_HF_HOSTED_MODEL_NAMES: + raise NotImplementedError( + "Model not hosted on HuggingFace, must pass in hf_model" + ) elif "bert" in official_model_name: hf_model = BertForPreTraining.from_pretrained( official_model_name, torch_dtype=dtype, **kwargs From 6867800143486644b928e6ff8b0829ff40a65e7a Mon Sep 17 00:00:00 2001 From: Artyom K Date: Thu, 18 Jan 2024 01:37:34 +0300 Subject: [PATCH 29/73] Fixe #371: Resolve issues 
where LLama will not load on CUDA (#461) --- transformer_lens/loading_from_pretrained.py | 22 ++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 994a2b3a2..aa67e653d 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1484,20 +1484,22 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.attn.W_V"] = W_V state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( - cfg.n_heads, cfg.d_head, dtype=cfg.dtype + cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device ) state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros( - cfg.n_heads, cfg.d_head, dtype=cfg.dtype + cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device ) state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros( - cfg.n_heads, cfg.d_head, dtype=cfg.dtype + cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device ) W_O = llama.model.layers[l].self_attn.o_proj.weight W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) state_dict[f"blocks.{l}.attn.W_O"] = W_O - state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros( + cfg.d_model, dtype=cfg.dtype, device=cfg.device + ) state_dict[f"blocks.{l}.ln2.w"] = llama.model.layers[ l @@ -1507,17 +1509,23 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[ l ].mlp.gate_proj.weight.T - state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) + state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros( + cfg.d_mlp, dtype=cfg.dtype, device=cfg.device + ) state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[ l ].mlp.down_proj.weight.T - state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros( + cfg.d_model, 
dtype=cfg.dtype, device=cfg.device + ) state_dict["ln_final.w"] = llama.model.norm.weight state_dict["unembed.W_U"] = llama.lm_head.weight.T - state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) + state_dict["unembed.b_U"] = torch.zeros( + cfg.d_vocab, dtype=cfg.dtype, device=cfg.device + ) return state_dict From a5147baea899f16f0db34b1a7b4e3464d3fd4b30 Mon Sep 17 00:00:00 2001 From: Jacob Xiaochen Li <35388161+SeuperHakkerJa@users.noreply.github.com> Date: Thu, 18 Jan 2024 06:38:45 +0800 Subject: [PATCH 30/73] Add support for larger Bloom models (up to 7b) (#447) --- transformer_lens/loading_from_pretrained.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index aa67e653d..10345a069 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -140,6 +140,10 @@ "stabilityai/stablelm-tuned-alpha-3b", "stabilityai/stablelm-tuned-alpha-7b", "bigscience/bloom-560m", + "bigscience/bloom-1b1", + "bigscience/bloom-1b7", + "bigscience/bloom-3b", + "bigscience/bloom-7b1", "bigcode/santacoder", "Qwen/Qwen-1_8B", "Qwen/Qwen-7B", @@ -521,6 +525,10 @@ "stablelm-tuned-7b", ], "bigscience/bloom-560m": ["bloom-560m"], + "bigscience/bloom-1b1": ["bloom-1b1"], + "bigscience/bloom-1b7": ["bloom-1b7"], + "bigscience/bloom-3b": ["bloom-3b"], + "bigscience/bloom-7b1": ["bloom-7b1"], "bigcode/santacoder": ["santacoder"], "Qwen/Qwen-1_8B": ["qwen-1.8b"], "Qwen/Qwen-7B": ["qwen-7b"], @@ -1956,10 +1964,8 @@ def convert_bloom_weights(bloom, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.ln1.w"] = bloom.transformer.h[l].input_layernorm.weight state_dict[f"blocks.{l}.ln1.b"] = bloom.transformer.h[l].input_layernorm.bias - # Bloom attn weight is stored as a fused matrx. 
BloomAttn: Linear(in=1024, out=3072) - # The .weight returned matrix will be in shape (3072, 1024) W = bloom.transformer.h[l].self_attention.query_key_value.weight - # First transpose -> (1024, 3072), then split into (d_model, n_heads, 3, d_head) + W_split = W.T.reshape(cfg.d_model, cfg.n_heads, 3, cfg.d_head) W_Q, W_K, W_V = W_split[..., 0, :], W_split[..., 1, :], W_split[..., 2, :] @@ -2004,7 +2010,7 @@ def convert_bloom_weights(bloom, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.mlp.b_out"] = bloom.transformer.h[ l ].mlp.dense_4h_to_h.bias - state_dict["unembed.W_U"] = bloom.lm_head.weight.T # transpose to match shape + state_dict["unembed.W_U"] = bloom.lm_head.weight.T state_dict["ln_final.w"] = bloom.transformer.ln_f.weight state_dict["ln_final.b"] = bloom.transformer.ln_f.bias From 11edb28fb463b1aa75191568101a6d48bdbd7276 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Hofst=C3=A4tter?= Date: Mon, 22 Jan 2024 03:08:20 -0800 Subject: [PATCH 31/73] Add mistral 7b support (#443) * add mistral model name and alias * add code for converting mistral config to hooked transformer config * add function for converting mistral weights * add GroupedQueryAttention * add abstract base class for attention * adapt keyvaluecache if grouped query attention is used * fix fold_value_biases when using grouped query attention * Add unit test for grouped query attention * Add demo notebook for Mistral * fix formatting * add documentation for grouped query attention * update lock file * use Union instead of | for union types * hardcode mistral config so building docs works with older versions of transformers * don't set final_rms in Mistral config * make Mistral-7b's alias name consistent with other models * update main demo notebook * require transformers>=3.34 * improve docstrings and clarify test name for grouped query attention * remove Mistral demo * fix docstring format --- demos/Main_Demo.ipynb | 46 +- poetry.lock | 1471 +++++++++++-------- pyproject.toml | 3 +- 
tests/unit/test_grouped_query_attention.py | 82 ++ transformer_lens/HookedTransformer.py | 20 +- transformer_lens/HookedTransformerConfig.py | 3 + transformer_lens/components.py | 395 ++++- transformer_lens/loading_from_pretrained.py | 95 +- transformer_lens/past_key_value_caching.py | 7 +- 9 files changed, 1398 insertions(+), 724 deletions(-) create mode 100644 tests/unit/test_grouped_query_attention.py diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb index d5d524c76..c871a6bd8 100644 --- a/demos/Main_Demo.ipynb +++ b/demos/Main_Demo.ipynb @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 292, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -80,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 293, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -103,32 +103,28 @@ }, { "cell_type": "code", - "execution_count": 294, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 294, - "metadata": { - "text/html": { - "Content-Type": "text/html" - } - }, + "execution_count": 13, + "metadata": {}, "output_type": "execute_result" } ], @@ -140,7 +136,7 @@ }, { "cell_type": "code", - "execution_count": 295, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -158,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 296, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -179,16 +175,16 @@ }, { "cell_type": "code", - "execution_count": 297, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 297, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -254,7 +250,7 @@ }, { "cell_type": "code", - "execution_count": 299, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -263,7 +259,7 @@ }, { "cell_type": "code", - "execution_count": 300, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -1210,7 +1206,7 @@ }, { "cell_type": "code", - "execution_count": 314, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -1218,13 +1214,13 @@ "output_type": "stream", "text": [ "blocks.0.attn.W_Q torch.Size([12, 768, 64])\n", - "blocks.0.attn.W_K torch.Size([12, 768, 64])\n", - "blocks.0.attn.W_V torch.Size([12, 768, 64])\n", "blocks.0.attn.W_O torch.Size([12, 64, 768])\n", "blocks.0.attn.b_Q torch.Size([12, 64])\n", + "blocks.0.attn.b_O torch.Size([768])\n", + "blocks.0.attn.W_K torch.Size([12, 768, 64])\n", + "blocks.0.attn.W_V torch.Size([12, 768, 64])\n", "blocks.0.attn.b_K torch.Size([12, 64])\n", "blocks.0.attn.b_V torch.Size([12, 64])\n", - "blocks.0.attn.b_O torch.Size([768])\n", "blocks.0.mlp.W_in torch.Size([768, 3072])\n", "blocks.0.mlp.b_in torch.Size([3072])\n", "blocks.0.mlp.W_out torch.Size([3072, 768])\n", @@ -1247,7 +1243,7 @@ }, { "cell_type": "code", - "execution_count": 315, + "execution_count": 20, "metadata": 
{}, "outputs": [ { diff --git a/poetry.lock b/poetry.lock index c95a1e5f2..5b7b84584 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,9 +1,10 @@ -# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. +# This file is automatically @generated by Poetry and should not be changed by hand. [[package]] name = "accelerate" version = "0.25.0" description = "Accelerate" +category = "main" optional = false python-versions = ">=3.8.0" files = [ @@ -34,6 +35,7 @@ testing = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "parameterized", name = "aiohttp" version = "3.9.1" description = "Async http client/server framework (asyncio)" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -130,6 +132,7 @@ speedups = ["Brotli", "aiodns", "brotlicffi"] name = "aiosignal" version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -144,6 +147,7 @@ frozenlist = ">=1.1.0" name = "alabaster" version = "0.7.13" description = "A configurable sidebar-enabled Sphinx theme" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -153,19 +157,21 @@ files = [ [[package]] name = "anyio" -version = "4.1.0" +version = "4.2.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" +category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "anyio-4.1.0-py3-none-any.whl", hash = "sha256:56a415fbc462291813a94528a779597226619c8e78af7de0507333f700011e5f"}, - {file = "anyio-4.1.0.tar.gz", hash = "sha256:5a0bec7085176715be77df87fc66d6c9d70626bd752fcc85f57cdbee5b3760da"}, + {file = "anyio-4.2.0-py3-none-any.whl", hash = "sha256:745843b39e829e108e518c489b31dc757de7d2131d53fac32bd8df268227bfee"}, + {file = "anyio-4.2.0.tar.gz", hash = "sha256:e1875bb4b4e2de1669f4bc7869b6d3f54231cdced71605e6e64c9be77e3be50f"}, ] [package.dependencies] exceptiongroup = {version 
= ">=1.0.2", markers = "python_version < \"3.11\""} idna = ">=2.8" sniffio = ">=1.1" +typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} [package.extras] doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] @@ -176,6 +182,7 @@ trio = ["trio (>=0.23)"] name = "appdirs" version = "1.4.4" description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "main" optional = false python-versions = "*" files = [ @@ -187,6 +194,7 @@ files = [ name = "appnope" version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" +category = "dev" optional = false python-versions = "*" files = [ @@ -198,6 +206,7 @@ files = [ name = "argon2-cffi" version = "23.1.0" description = "Argon2 for Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -218,6 +227,7 @@ typing = ["mypy"] name = "argon2-cffi-bindings" version = "21.2.0" description = "Low-level CFFI bindings for Argon2" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -255,6 +265,7 @@ tests = ["pytest"] name = "arrow" version = "1.3.0" description = "Better dates & times for Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -268,12 +279,13 @@ types-python-dateutil = ">=2.8.10" [package.extras] doc = ["doc8", "sphinx (>=7.0.0)", "sphinx-autobuild", "sphinx-autodoc-typehints", "sphinx_rtd_theme (>=1.3.0)"] -test = ["dateparser (==1.*)", "pre-commit", "pytest", "pytest-cov", "pytest-mock", "pytz (==2021.1)", "simplejson (==3.*)"] +test = ["dateparser (>=1.0.0,<2.0.0)", "pre-commit", "pytest", "pytest-cov", "pytest-mock", "pytz (==2021.1)", "simplejson (>=3.0.0,<4.0.0)"] [[package]] name = "asttokens" version = "2.4.1" description = "Annotate AST trees with source code positions" +category = "dev" optional = false python-versions = "*" files = [ @@ -292,6 +304,7 @@ test = ["astroid (>=1,<2)", "astroid (>=2,<4)", 
"pytest"] name = "async-lru" version = "2.0.4" description = "Simple LRU cache for asyncio" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -306,6 +319,7 @@ typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} name = "async-timeout" version = "4.0.3" description = "Timeout context manager for asyncio programs" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -315,36 +329,38 @@ files = [ [[package]] name = "attrs" -version = "23.1.0" +version = "23.2.0" description = "Classes Without Boilerplate" +category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"}, - {file = "attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"}, + {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, + {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, ] [package.extras] cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] -dev = ["attrs[docs,tests]", "pre-commit"] +dev = ["attrs[tests]", "pre-commit"] docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] tests = ["attrs[tests-no-zope]", "zope-interface"] -tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] +tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] [[package]] name = "babel" -version = "2.13.1" +version = "2.14.0" description = "Internationalization utilities" +category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = 
"Babel-2.13.1-py3-none-any.whl", hash = "sha256:7077a4984b02b6727ac10f1f7294484f737443d7e2e66c5e4380e41a3ae0b4ed"}, - {file = "Babel-2.13.1.tar.gz", hash = "sha256:33e0952d7dd6374af8dbf6768cc4ddf3ccfefc244f9986d4074704f2fbd18900"}, + {file = "Babel-2.14.0-py3-none-any.whl", hash = "sha256:efb1a25b7118e67ce3a259bed20545c29cb68be8ad2c784c83689981b7a57287"}, + {file = "Babel-2.14.0.tar.gz", hash = "sha256:6919867db036398ba21eb5c7a0f6b28ab8cbc3ae7a73a44ebe34ae74a4e7d363"}, ] [package.dependencies] pytz = {version = ">=2015.7", markers = "python_version < \"3.9\""} -setuptools = {version = "*", markers = "python_version >= \"3.12\""} [package.extras] dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"] @@ -353,6 +369,7 @@ dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"] name = "backcall" version = "0.2.0" description = "Specifications for callback functions passed in to an API" +category = "dev" optional = false python-versions = "*" files = [ @@ -364,6 +381,7 @@ files = [ name = "beartype" version = "0.14.1" description = "Unbearably fast runtime type checking in pure Python." 
+category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -382,6 +400,7 @@ test-tox-coverage = ["coverage (>=5.5)"] name = "beautifulsoup4" version = "4.12.2" description = "Screen-scraping library" +category = "dev" optional = false python-versions = ">=3.6.0" files = [ @@ -396,31 +415,48 @@ soupsieve = ">1.2" html5lib = ["html5lib"] lxml = ["lxml"] +[[package]] +name = "better-abc" +version = "0.0.3" +description = "Python ABC plus abstract attributes" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "better-abc-0.0.3.tar.gz", hash = "sha256:a880fd6bc9675da2ec991e8712a555bffa0f12722efed78c739f78343cf989f6"}, + {file = "better_abc-0.0.3-py3-none-any.whl", hash = "sha256:3ae73b473fbeb536a548f542984976e80b821676ae6e18f14e24d8e180647187"}, +] + [[package]] name = "black" -version = "23.11.0" +version = "23.12.1" description = "The uncompromising code formatter." +category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "black-23.11.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dbea0bb8575c6b6303cc65017b46351dc5953eea5c0a59d7b7e3a2d2f433a911"}, - {file = "black-23.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:412f56bab20ac85927f3a959230331de5614aecda1ede14b373083f62ec24e6f"}, - {file = "black-23.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d136ef5b418c81660ad847efe0e55c58c8208b77a57a28a503a5f345ccf01394"}, - {file = "black-23.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:6c1cac07e64433f646a9a838cdc00c9768b3c362805afc3fce341af0e6a9ae9f"}, - {file = "black-23.11.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cf57719e581cfd48c4efe28543fea3d139c6b6f1238b3f0102a9c73992cbb479"}, - {file = "black-23.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:698c1e0d5c43354ec5d6f4d914d0d553a9ada56c85415700b81dc90125aac244"}, - {file = "black-23.11.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:760415ccc20f9e8747084169110ef75d545f3b0932ee21368f63ac0fee86b221"}, - {file = "black-23.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:58e5f4d08a205b11800332920e285bd25e1a75c54953e05502052738fe16b3b5"}, - {file = "black-23.11.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:45aa1d4675964946e53ab81aeec7a37613c1cb71647b5394779e6efb79d6d187"}, - {file = "black-23.11.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c44b7211a3a0570cc097e81135faa5f261264f4dfaa22bd5ee2875a4e773bd6"}, - {file = "black-23.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a9acad1451632021ee0d146c8765782a0c3846e0e0ea46659d7c4f89d9b212b"}, - {file = "black-23.11.0-cp38-cp38-win_amd64.whl", hash = "sha256:fc7f6a44d52747e65a02558e1d807c82df1d66ffa80a601862040a43ec2e3142"}, - {file = "black-23.11.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7f622b6822f02bfaf2a5cd31fdb7cd86fcf33dab6ced5185c35f5db98260b055"}, - {file = "black-23.11.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:250d7e60f323fcfc8ea6c800d5eba12f7967400eb6c2d21ae85ad31c204fb1f4"}, - {file = "black-23.11.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5133f5507007ba08d8b7b263c7aa0f931af5ba88a29beacc4b2dc23fcefe9c06"}, - {file = "black-23.11.0-cp39-cp39-win_amd64.whl", hash = "sha256:421f3e44aa67138ab1b9bfbc22ee3780b22fa5b291e4db8ab7eee95200726b07"}, - {file = "black-23.11.0-py3-none-any.whl", hash = "sha256:54caaa703227c6e0c87b76326d0862184729a69b73d3b7305b6288e1d830067e"}, - {file = "black-23.11.0.tar.gz", hash = "sha256:4c68855825ff432d197229846f971bc4d6666ce90492e5b02013bcaca4d9ab05"}, + {file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"}, + {file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"}, + {file = 
"black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"}, + {file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"}, + {file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"}, + {file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"}, + {file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"}, + {file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"}, + {file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"}, + {file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"}, + {file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"}, + {file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"}, + {file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"}, + {file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"}, + {file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"}, + {file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = 
"sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"}, + {file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"}, + {file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"}, + {file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"}, + {file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"}, + {file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"}, + {file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"}, ] [package.dependencies] @@ -434,7 +470,7 @@ typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} [package.extras] colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)"] +d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] @@ -442,6 +478,7 @@ uvloop = ["uvloop (>=0.15.2)"] name = "bleach" version = "6.1.0" description = "An easy safelist-based HTML-sanitizing tool." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -460,6 +497,7 @@ css = ["tinycss2 (>=1.1.0,<1.3)"] name = "certifi" version = "2023.11.17" description = "Python package for providing Mozilla's CA Bundle." +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -471,6 +509,7 @@ files = [ name = "cffi" version = "1.16.0" description = "Foreign Function Interface for Python calling C code." 
+category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -535,6 +574,7 @@ pycparser = "*" name = "charset-normalizer" version = "3.3.2" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -634,6 +674,7 @@ files = [ name = "circuitsvis" version = "1.43.2" description = "Mechanistic Interpretability Visualizations" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -666,6 +707,7 @@ triton = {version = "2.1.0", markers = "platform_system == \"Linux\" and platfor name = "click" version = "8.1.7" description = "Composable command line interface toolkit" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -680,6 +722,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -689,13 +732,14 @@ files = [ [[package]] name = "comm" -version = "0.2.0" +version = "0.2.1" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." 
+category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "comm-0.2.0-py3-none-any.whl", hash = "sha256:2da8d9ebb8dd7bfc247adaff99f24dce705638a8042b85cb995066793e391001"}, - {file = "comm-0.2.0.tar.gz", hash = "sha256:a517ea2ca28931c7007a7a99c562a0fa5883cfb48963140cf642c41c948498be"}, + {file = "comm-0.2.1-py3-none-any.whl", hash = "sha256:87928485c0dfc0e7976fd89fc1e187023cf587e7c353e4a9b417555b44adf021"}, + {file = "comm-0.2.1.tar.gz", hash = "sha256:0bc91edae1344d39d3661dcbc36937181fdaddb304790458f8b044dbc064b89a"}, ] [package.dependencies] @@ -706,63 +750,64 @@ test = ["pytest"] [[package]] name = "coverage" -version = "7.3.2" +version = "7.4.0" description = "Code coverage measurement for Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "coverage-7.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d872145f3a3231a5f20fd48500274d7df222e291d90baa2026cc5152b7ce86bf"}, - {file = "coverage-7.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:310b3bb9c91ea66d59c53fa4989f57d2436e08f18fb2f421a1b0b6b8cc7fffda"}, - {file = "coverage-7.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f47d39359e2c3779c5331fc740cf4bce6d9d680a7b4b4ead97056a0ae07cb49a"}, - {file = "coverage-7.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aa72dbaf2c2068404b9870d93436e6d23addd8bbe9295f49cbca83f6e278179c"}, - {file = "coverage-7.3.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:beaa5c1b4777f03fc63dfd2a6bd820f73f036bfb10e925fce067b00a340d0f3f"}, - {file = "coverage-7.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:dbc1b46b92186cc8074fee9d9fbb97a9dd06c6cbbef391c2f59d80eabdf0faa6"}, - {file = "coverage-7.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:315a989e861031334d7bee1f9113c8770472db2ac484e5b8c3173428360a9148"}, - {file = 
"coverage-7.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d1bc430677773397f64a5c88cb522ea43175ff16f8bfcc89d467d974cb2274f9"}, - {file = "coverage-7.3.2-cp310-cp310-win32.whl", hash = "sha256:a889ae02f43aa45032afe364c8ae84ad3c54828c2faa44f3bfcafecb5c96b02f"}, - {file = "coverage-7.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:c0ba320de3fb8c6ec16e0be17ee1d3d69adcda99406c43c0409cb5c41788a611"}, - {file = "coverage-7.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ac8c802fa29843a72d32ec56d0ca792ad15a302b28ca6203389afe21f8fa062c"}, - {file = "coverage-7.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:89a937174104339e3a3ffcf9f446c00e3a806c28b1841c63edb2b369310fd074"}, - {file = "coverage-7.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e267e9e2b574a176ddb983399dec325a80dbe161f1a32715c780b5d14b5f583a"}, - {file = "coverage-7.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2443cbda35df0d35dcfb9bf8f3c02c57c1d6111169e3c85fc1fcc05e0c9f39a3"}, - {file = "coverage-7.3.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4175e10cc8dda0265653e8714b3174430b07c1dca8957f4966cbd6c2b1b8065a"}, - {file = "coverage-7.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0cbf38419fb1a347aaf63481c00f0bdc86889d9fbf3f25109cf96c26b403fda1"}, - {file = "coverage-7.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5c913b556a116b8d5f6ef834038ba983834d887d82187c8f73dec21049abd65c"}, - {file = "coverage-7.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1981f785239e4e39e6444c63a98da3a1db8e971cb9ceb50a945ba6296b43f312"}, - {file = "coverage-7.3.2-cp311-cp311-win32.whl", hash = "sha256:43668cabd5ca8258f5954f27a3aaf78757e6acf13c17604d89648ecc0cc66640"}, - {file = "coverage-7.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10c39c0452bf6e694511c901426d6b5ac005acc0f78ff265dbe36bf81f808a2"}, - 
{file = "coverage-7.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4cbae1051ab791debecc4a5dcc4a1ff45fc27b91b9aee165c8a27514dd160836"}, - {file = "coverage-7.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12d15ab5833a997716d76f2ac1e4b4d536814fc213c85ca72756c19e5a6b3d63"}, - {file = "coverage-7.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c7bba973ebee5e56fe9251300c00f1579652587a9f4a5ed8404b15a0471f216"}, - {file = "coverage-7.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fe494faa90ce6381770746077243231e0b83ff3f17069d748f645617cefe19d4"}, - {file = "coverage-7.3.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6e9589bd04d0461a417562649522575d8752904d35c12907d8c9dfeba588faf"}, - {file = "coverage-7.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d51ac2a26f71da1b57f2dc81d0e108b6ab177e7d30e774db90675467c847bbdf"}, - {file = "coverage-7.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:99b89d9f76070237975b315b3d5f4d6956ae354a4c92ac2388a5695516e47c84"}, - {file = "coverage-7.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fa28e909776dc69efb6ed975a63691bc8172b64ff357e663a1bb06ff3c9b589a"}, - {file = "coverage-7.3.2-cp312-cp312-win32.whl", hash = "sha256:289fe43bf45a575e3ab10b26d7b6f2ddb9ee2dba447499f5401cfb5ecb8196bb"}, - {file = "coverage-7.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:7dbc3ed60e8659bc59b6b304b43ff9c3ed858da2839c78b804973f613d3e92ed"}, - {file = "coverage-7.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f94b734214ea6a36fe16e96a70d941af80ff3bfd716c141300d95ebc85339738"}, - {file = "coverage-7.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:af3d828d2c1cbae52d34bdbb22fcd94d1ce715d95f1a012354a75e5913f1bda2"}, - {file = "coverage-7.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:630b13e3036e13c7adc480ca42fa7afc2a5d938081d28e20903cf7fd687872e2"}, - {file = "coverage-7.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c9eacf273e885b02a0273bb3a2170f30e2d53a6d53b72dbe02d6701b5296101c"}, - {file = "coverage-7.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8f17966e861ff97305e0801134e69db33b143bbfb36436efb9cfff6ec7b2fd9"}, - {file = "coverage-7.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b4275802d16882cf9c8b3d057a0839acb07ee9379fa2749eca54efbce1535b82"}, - {file = "coverage-7.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:72c0cfa5250f483181e677ebc97133ea1ab3eb68645e494775deb6a7f6f83901"}, - {file = "coverage-7.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cb536f0dcd14149425996821a168f6e269d7dcd2c273a8bff8201e79f5104e76"}, - {file = "coverage-7.3.2-cp38-cp38-win32.whl", hash = "sha256:307adb8bd3abe389a471e649038a71b4eb13bfd6b7dd9a129fa856f5c695cf92"}, - {file = "coverage-7.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:88ed2c30a49ea81ea3b7f172e0269c182a44c236eb394718f976239892c0a27a"}, - {file = "coverage-7.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b631c92dfe601adf8f5ebc7fc13ced6bb6e9609b19d9a8cd59fa47c4186ad1ce"}, - {file = "coverage-7.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d3d9df4051c4a7d13036524b66ecf7a7537d14c18a384043f30a303b146164e9"}, - {file = "coverage-7.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f7363d3b6a1119ef05015959ca24a9afc0ea8a02c687fe7e2d557705375c01f"}, - {file = "coverage-7.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f11cc3c967a09d3695d2a6f03fb3e6236622b93be7a4b5dc09166a861be6d25"}, - {file = "coverage-7.3.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:149de1d2401ae4655c436a3dced6dd153f4c3309f599c3d4bd97ab172eaf02d9"}, - {file = "coverage-7.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:3a4006916aa6fee7cd38db3bfc95aa9c54ebb4ffbfc47c677c8bba949ceba0a6"}, - {file = "coverage-7.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9028a3871280110d6e1aa2df1afd5ef003bab5fb1ef421d6dc748ae1c8ef2ebc"}, - {file = "coverage-7.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9f805d62aec8eb92bab5b61c0f07329275b6f41c97d80e847b03eb894f38d083"}, - {file = "coverage-7.3.2-cp39-cp39-win32.whl", hash = "sha256:d1c88ec1a7ff4ebca0219f5b1ef863451d828cccf889c173e1253aa84b1e07ce"}, - {file = "coverage-7.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b4767da59464bb593c07afceaddea61b154136300881844768037fd5e859353f"}, - {file = "coverage-7.3.2-pp38.pp39.pp310-none-any.whl", hash = "sha256:ae97af89f0fbf373400970c0a21eef5aa941ffeed90aee43650b81f7d7f47637"}, - {file = "coverage-7.3.2.tar.gz", hash = "sha256:be32ad29341b0170e795ca590e1c07e81fc061cb5b10c74ce7203491484404ef"}, + {file = "coverage-7.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:36b0ea8ab20d6a7564e89cb6135920bc9188fb5f1f7152e94e8300b7b189441a"}, + {file = "coverage-7.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0676cd0ba581e514b7f726495ea75aba3eb20899d824636c6f59b0ed2f88c471"}, + {file = "coverage-7.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0ca5c71a5a1765a0f8f88022c52b6b8be740e512980362f7fdbb03725a0d6b9"}, + {file = "coverage-7.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7c97726520f784239f6c62506bc70e48d01ae71e9da128259d61ca5e9788516"}, + {file = "coverage-7.4.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:815ac2d0f3398a14286dc2cea223a6f338109f9ecf39a71160cd1628786bc6f5"}, + {file = "coverage-7.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = 
"sha256:80b5ee39b7f0131ebec7968baa9b2309eddb35b8403d1869e08f024efd883566"}, + {file = "coverage-7.4.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5b2ccb7548a0b65974860a78c9ffe1173cfb5877460e5a229238d985565574ae"}, + {file = "coverage-7.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:995ea5c48c4ebfd898eacb098164b3cc826ba273b3049e4a889658548e321b43"}, + {file = "coverage-7.4.0-cp310-cp310-win32.whl", hash = "sha256:79287fd95585ed36e83182794a57a46aeae0b64ca53929d1176db56aacc83451"}, + {file = "coverage-7.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:5b14b4f8760006bfdb6e08667af7bc2d8d9bfdb648351915315ea17645347137"}, + {file = "coverage-7.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:04387a4a6ecb330c1878907ce0dc04078ea72a869263e53c72a1ba5bbdf380ca"}, + {file = "coverage-7.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ea81d8f9691bb53f4fb4db603203029643caffc82bf998ab5b59ca05560f4c06"}, + {file = "coverage-7.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74775198b702868ec2d058cb92720a3c5a9177296f75bd97317c787daf711505"}, + {file = "coverage-7.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76f03940f9973bfaee8cfba70ac991825611b9aac047e5c80d499a44079ec0bc"}, + {file = "coverage-7.4.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:485e9f897cf4856a65a57c7f6ea3dc0d4e6c076c87311d4bc003f82cfe199d25"}, + {file = "coverage-7.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6ae8c9d301207e6856865867d762a4b6fd379c714fcc0607a84b92ee63feff70"}, + {file = "coverage-7.4.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bf477c355274a72435ceb140dc42de0dc1e1e0bf6e97195be30487d8eaaf1a09"}, + {file = "coverage-7.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:83c2dda2666fe32332f8e87481eed056c8b4d163fe18ecc690b02802d36a4d26"}, + {file = 
"coverage-7.4.0-cp311-cp311-win32.whl", hash = "sha256:697d1317e5290a313ef0d369650cfee1a114abb6021fa239ca12b4849ebbd614"}, + {file = "coverage-7.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:26776ff6c711d9d835557ee453082025d871e30b3fd6c27fcef14733f67f0590"}, + {file = "coverage-7.4.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:13eaf476ec3e883fe3e5fe3707caeb88268a06284484a3daf8250259ef1ba143"}, + {file = "coverage-7.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:846f52f46e212affb5bcf131c952fb4075b55aae6b61adc9856222df89cbe3e2"}, + {file = "coverage-7.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26f66da8695719ccf90e794ed567a1549bb2644a706b41e9f6eae6816b398c4a"}, + {file = "coverage-7.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:164fdcc3246c69a6526a59b744b62e303039a81e42cfbbdc171c91a8cc2f9446"}, + {file = "coverage-7.4.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:316543f71025a6565677d84bc4df2114e9b6a615aa39fb165d697dba06a54af9"}, + {file = "coverage-7.4.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bb1de682da0b824411e00a0d4da5a784ec6496b6850fdf8c865c1d68c0e318dd"}, + {file = "coverage-7.4.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:0e8d06778e8fbffccfe96331a3946237f87b1e1d359d7fbe8b06b96c95a5407a"}, + {file = "coverage-7.4.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a56de34db7b7ff77056a37aedded01b2b98b508227d2d0979d373a9b5d353daa"}, + {file = "coverage-7.4.0-cp312-cp312-win32.whl", hash = "sha256:51456e6fa099a8d9d91497202d9563a320513fcf59f33991b0661a4a6f2ad450"}, + {file = "coverage-7.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:cd3c1e4cb2ff0083758f09be0f77402e1bdf704adb7f89108007300a6da587d0"}, + {file = "coverage-7.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e9d1bf53c4c8de58d22e0e956a79a5b37f754ed1ffdbf1a260d9dcfa2d8a325e"}, + {file = 
"coverage-7.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:109f5985182b6b81fe33323ab4707011875198c41964f014579cf82cebf2bb85"}, + {file = "coverage-7.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cc9d4bc55de8003663ec94c2f215d12d42ceea128da8f0f4036235a119c88ac"}, + {file = "coverage-7.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cc6d65b21c219ec2072c1293c505cf36e4e913a3f936d80028993dd73c7906b1"}, + {file = "coverage-7.4.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a10a4920def78bbfff4eff8a05c51be03e42f1c3735be42d851f199144897ba"}, + {file = "coverage-7.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b8e99f06160602bc64da35158bb76c73522a4010f0649be44a4e167ff8555952"}, + {file = "coverage-7.4.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:7d360587e64d006402b7116623cebf9d48893329ef035278969fa3bbf75b697e"}, + {file = "coverage-7.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:29f3abe810930311c0b5d1a7140f6395369c3db1be68345638c33eec07535105"}, + {file = "coverage-7.4.0-cp38-cp38-win32.whl", hash = "sha256:5040148f4ec43644702e7b16ca864c5314ccb8ee0751ef617d49aa0e2d6bf4f2"}, + {file = "coverage-7.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:9864463c1c2f9cb3b5db2cf1ff475eed2f0b4285c2aaf4d357b69959941aa555"}, + {file = "coverage-7.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:936d38794044b26c99d3dd004d8af0035ac535b92090f7f2bb5aa9c8e2f5cd42"}, + {file = "coverage-7.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:799c8f873794a08cdf216aa5d0531c6a3747793b70c53f70e98259720a6fe2d7"}, + {file = "coverage-7.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e7defbb9737274023e2d7af02cac77043c86ce88a907c58f42b580a97d5bcca9"}, + {file = "coverage-7.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:a1526d265743fb49363974b7aa8d5899ff64ee07df47dd8d3e37dcc0818f09ed"}, + {file = "coverage-7.4.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf635a52fc1ea401baf88843ae8708591aa4adff875e5c23220de43b1ccf575c"}, + {file = "coverage-7.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:756ded44f47f330666843b5781be126ab57bb57c22adbb07d83f6b519783b870"}, + {file = "coverage-7.4.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:0eb3c2f32dabe3a4aaf6441dde94f35687224dfd7eb2a7f47f3fd9428e421058"}, + {file = "coverage-7.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bfd5db349d15c08311702611f3dccbef4b4e2ec148fcc636cf8739519b4a5c0f"}, + {file = "coverage-7.4.0-cp39-cp39-win32.whl", hash = "sha256:53d7d9158ee03956e0eadac38dfa1ec8068431ef8058fe6447043db1fb40d932"}, + {file = "coverage-7.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:cfd2a8b6b0d8e66e944d47cdec2f47c48fef2ba2f2dff5a9a75757f64172857e"}, + {file = "coverage-7.4.0-pp38.pp39.pp310-none-any.whl", hash = "sha256:c530833afc4707fe48524a44844493f36d8727f04dcce91fb978c414a8556cc6"}, + {file = "coverage-7.4.0.tar.gz", hash = "sha256:707c0f58cb1712b8809ece32b68996ee1e609f71bd14615bd8f87a1293cb610e"}, ] [package.dependencies] @@ -773,26 +818,26 @@ toml = ["tomli"] [[package]] name = "datasets" -version = "2.15.0" +version = "2.14.4" description = "HuggingFace community-driven open-source library of datasets" +category = "main" optional = false python-versions = ">=3.8.0" files = [ - {file = "datasets-2.15.0-py3-none-any.whl", hash = "sha256:6d658d23811393dfc982d026082e1650bdaaae28f6a86e651966cb072229a228"}, - {file = "datasets-2.15.0.tar.gz", hash = "sha256:a26d059370bd7503bd60e9337977199a13117a83f72fb61eda7e66f0c4d50b2b"}, + {file = "datasets-2.14.4-py3-none-any.whl", hash = "sha256:29336bd316a7d827ccd4da2236596279b20ca2ac78f64c04c9483da7cbc2459b"}, + {file = "datasets-2.14.4.tar.gz", hash = 
"sha256:ef29c2b5841de488cd343cfc26ab979bff77efa4d2285af51f1ad7db5c46a83b"}, ] [package.dependencies] aiohttp = "*" dill = ">=0.3.0,<0.3.8" -fsspec = {version = ">=2023.1.0,<=2023.10.0", extras = ["http"]} -huggingface-hub = ">=0.18.0" +fsspec = {version = ">=2021.11.1", extras = ["http"]} +huggingface-hub = ">=0.14.0,<1.0.0" multiprocess = "*" numpy = ">=1.17" packaging = "*" pandas = "*" pyarrow = ">=8.0.0" -pyarrow-hotfix = "*" pyyaml = ">=5.1" requests = ">=2.19.0" tqdm = ">=4.62.1" @@ -802,15 +847,15 @@ xxhash = "*" apache-beam = ["apache-beam (>=2.26.0,<2.44.0)"] audio = ["librosa", "soundfile (>=0.12.1)"] benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"] -dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff (>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] +dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff (>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos", "torch", "transformers"] -jax = ["jax 
(>=0.3.14)", "jaxlib (>=0.3.14)"] +jax = ["jax (>=0.2.8,!=0.3.2,<=0.3.25)", "jaxlib (>=0.1.65,<=0.3.25)"] metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] quality = ["black (>=23.1,<24.0)", "pyyaml (>=5.3.1)", "ruff (>=0.0.241)"] s3 = ["s3fs"] tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] -tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] +tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] torch = ["torch"] vision = ["Pillow (>=6.2.1)"] @@ -818,6 +863,7 @@ vision = ["Pillow (>=6.2.1)"] name = "debugpy" version = "1.8.0" description = "An implementation of the Debug Adapter Protocol for Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -845,6 +891,7 @@ files = [ name = 
"decorator" version = "5.1.1" description = "Decorators for Humans" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -856,6 +903,7 @@ files = [ name = "defusedxml" version = "0.7.1" description = "XML bomb protection for Python stdlib modules" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -867,6 +915,7 @@ files = [ name = "dill" version = "0.3.7" description = "serialize all of Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -881,6 +930,7 @@ graph = ["objgraph (>=1.7.2)"] name = "docker-pycreds" version = "0.4.0" description = "Python bindings for the docker credentials store API" +category = "main" optional = false python-versions = "*" files = [ @@ -895,6 +945,7 @@ six = ">=1.4.0" name = "docutils" version = "0.19" description = "Docutils -- Python Documentation Utilities" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -906,6 +957,7 @@ files = [ name = "einops" version = "0.7.0" description = "A new flavour of deep learning operations" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -917,6 +969,7 @@ files = [ name = "exceptiongroup" version = "1.2.0" description = "Backport of PEP 654 (exception groups)" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -931,6 +984,7 @@ test = ["pytest (>=6)"] name = "executing" version = "2.0.1" description = "Get the currently executing AST node of a frame, and other information" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -945,6 +999,7 @@ tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipyth name = "fancy-einsum" version = "0.0.3" description = "Drop-in replacement for torch/numpy einsum, with descriptive variable names in equations" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -954,13 +1009,14 @@ files = [ [[package]] name = 
"fastjsonschema" -version = "2.19.0" +version = "2.19.1" description = "Fastest Python implementation of JSON schema" +category = "dev" optional = false python-versions = "*" files = [ - {file = "fastjsonschema-2.19.0-py3-none-any.whl", hash = "sha256:b9fd1a2dd6971dbc7fee280a95bd199ae0dd9ce22beb91cc75e9c1c528a5170e"}, - {file = "fastjsonschema-2.19.0.tar.gz", hash = "sha256:e25df6647e1bc4a26070b700897b07b542ec898dd4f1f6ea013e7f6a88417225"}, + {file = "fastjsonschema-2.19.1-py3-none-any.whl", hash = "sha256:3672b47bc94178c9f23dbb654bf47440155d4db9df5f7bc47643315f9c405cd0"}, + {file = "fastjsonschema-2.19.1.tar.gz", hash = "sha256:e3126a94bdc4623d3de4485f8d468a12f02a67921315ddc87836d6e456dc789d"}, ] [package.extras] @@ -970,6 +1026,7 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.13.1" description = "A platform independent file lock." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -986,6 +1043,7 @@ typing = ["typing-extensions (>=4.8)"] name = "fqdn" version = "1.5.1" description = "Validates fully-qualified domain names against RFC 1123, so that they are acceptable to modern bowsers" +category = "dev" optional = false python-versions = ">=2.7, !=3.0, !=3.1, !=3.2, !=3.3, !=3.4, <4" files = [ @@ -995,83 +1053,101 @@ files = [ [[package]] name = "frozenlist" -version = "1.4.0" +version = "1.4.1" description = "A list-like structure which implements collections.abc.MutableSequence" +category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:764226ceef3125e53ea2cb275000e309c0aa5464d43bd72abd661e27fffc26ab"}, - {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d6484756b12f40003c6128bfcc3fa9f0d49a687e171186c2d85ec82e3758c559"}, - {file = "frozenlist-1.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:9ac08e601308e41eb533f232dbf6b7e4cea762f9f84f6357136eed926c15d12c"}, - {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d081f13b095d74b67d550de04df1c756831f3b83dc9881c38985834387487f1b"}, - {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:71932b597f9895f011f47f17d6428252fc728ba2ae6024e13c3398a087c2cdea"}, - {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:981b9ab5a0a3178ff413bca62526bb784249421c24ad7381e39d67981be2c326"}, - {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e41f3de4df3e80de75845d3e743b3f1c4c8613c3997a912dbf0229fc61a8b963"}, - {file = "frozenlist-1.4.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6918d49b1f90821e93069682c06ffde41829c346c66b721e65a5c62b4bab0300"}, - {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0e5c8764c7829343d919cc2dfc587a8db01c4f70a4ebbc49abde5d4b158b007b"}, - {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8d0edd6b1c7fb94922bf569c9b092ee187a83f03fb1a63076e7774b60f9481a8"}, - {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e29cda763f752553fa14c68fb2195150bfab22b352572cb36c43c47bedba70eb"}, - {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:0c7c1b47859ee2cac3846fde1c1dc0f15da6cec5a0e5c72d101e0f83dcb67ff9"}, - {file = "frozenlist-1.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:901289d524fdd571be1c7be054f48b1f88ce8dddcbdf1ec698b27d4b8b9e5d62"}, - {file = "frozenlist-1.4.0-cp310-cp310-win32.whl", hash = "sha256:1a0848b52815006ea6596c395f87449f693dc419061cc21e970f139d466dc0a0"}, - {file = "frozenlist-1.4.0-cp310-cp310-win_amd64.whl", hash = 
"sha256:b206646d176a007466358aa21d85cd8600a415c67c9bd15403336c331a10d956"}, - {file = "frozenlist-1.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:de343e75f40e972bae1ef6090267f8260c1446a1695e77096db6cfa25e759a95"}, - {file = "frozenlist-1.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ad2a9eb6d9839ae241701d0918f54c51365a51407fd80f6b8289e2dfca977cc3"}, - {file = "frozenlist-1.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bd7bd3b3830247580de99c99ea2a01416dfc3c34471ca1298bccabf86d0ff4dc"}, - {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bdf1847068c362f16b353163391210269e4f0569a3c166bc6a9f74ccbfc7e839"}, - {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38461d02d66de17455072c9ba981d35f1d2a73024bee7790ac2f9e361ef1cd0c"}, - {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5a32087d720c608f42caed0ef36d2b3ea61a9d09ee59a5142d6070da9041b8f"}, - {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dd65632acaf0d47608190a71bfe46b209719bf2beb59507db08ccdbe712f969b"}, - {file = "frozenlist-1.4.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261b9f5d17cac914531331ff1b1d452125bf5daa05faf73b71d935485b0c510b"}, - {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b89ac9768b82205936771f8d2eb3ce88503b1556324c9f903e7156669f521472"}, - {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:008eb8b31b3ea6896da16c38c1b136cb9fec9e249e77f6211d479db79a4eaf01"}, - {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e74b0506fa5aa5598ac6a975a12aa8928cbb58e1f5ac8360792ef15de1aa848f"}, - {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = 
"sha256:490132667476f6781b4c9458298b0c1cddf237488abd228b0b3650e5ecba7467"}, - {file = "frozenlist-1.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:76d4711f6f6d08551a7e9ef28c722f4a50dd0fc204c56b4bcd95c6cc05ce6fbb"}, - {file = "frozenlist-1.4.0-cp311-cp311-win32.whl", hash = "sha256:a02eb8ab2b8f200179b5f62b59757685ae9987996ae549ccf30f983f40602431"}, - {file = "frozenlist-1.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:515e1abc578dd3b275d6a5114030b1330ba044ffba03f94091842852f806f1c1"}, - {file = "frozenlist-1.4.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f0ed05f5079c708fe74bf9027e95125334b6978bf07fd5ab923e9e55e5fbb9d3"}, - {file = "frozenlist-1.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ca265542ca427bf97aed183c1676e2a9c66942e822b14dc6e5f42e038f92a503"}, - {file = "frozenlist-1.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:491e014f5c43656da08958808588cc6c016847b4360e327a62cb308c791bd2d9"}, - {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17ae5cd0f333f94f2e03aaf140bb762c64783935cc764ff9c82dff626089bebf"}, - {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e78fb68cf9c1a6aa4a9a12e960a5c9dfbdb89b3695197aa7064705662515de2"}, - {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5655a942f5f5d2c9ed93d72148226d75369b4f6952680211972a33e59b1dfdc"}, - {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c11b0746f5d946fecf750428a95f3e9ebe792c1ee3b1e96eeba145dc631a9672"}, - {file = "frozenlist-1.4.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e66d2a64d44d50d2543405fb183a21f76b3b5fd16f130f5c99187c3fb4e64919"}, - {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = 
"sha256:88f7bc0fcca81f985f78dd0fa68d2c75abf8272b1f5c323ea4a01a4d7a614efc"}, - {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5833593c25ac59ede40ed4de6d67eb42928cca97f26feea219f21d0ed0959b79"}, - {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:fec520865f42e5c7f050c2a79038897b1c7d1595e907a9e08e3353293ffc948e"}, - {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:b826d97e4276750beca7c8f0f1a4938892697a6bcd8ec8217b3312dad6982781"}, - {file = "frozenlist-1.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ceb6ec0a10c65540421e20ebd29083c50e6d1143278746a4ef6bcf6153171eb8"}, - {file = "frozenlist-1.4.0-cp38-cp38-win32.whl", hash = "sha256:2b8bcf994563466db019fab287ff390fffbfdb4f905fc77bc1c1d604b1c689cc"}, - {file = "frozenlist-1.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:a6c8097e01886188e5be3e6b14e94ab365f384736aa1fca6a0b9e35bd4a30bc7"}, - {file = "frozenlist-1.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6c38721585f285203e4b4132a352eb3daa19121a035f3182e08e437cface44bf"}, - {file = "frozenlist-1.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a0c6da9aee33ff0b1a451e867da0c1f47408112b3391dd43133838339e410963"}, - {file = "frozenlist-1.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93ea75c050c5bb3d98016b4ba2497851eadf0ac154d88a67d7a6816206f6fa7f"}, - {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f61e2dc5ad442c52b4887f1fdc112f97caeff4d9e6ebe78879364ac59f1663e1"}, - {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa384489fefeb62321b238e64c07ef48398fe80f9e1e6afeff22e140e0850eef"}, - {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:10ff5faaa22786315ef57097a279b833ecab1a0bfb07d604c9cbb1c4cdc2ed87"}, - {file = 
"frozenlist-1.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:007df07a6e3eb3e33e9a1fe6a9db7af152bbd8a185f9aaa6ece10a3529e3e1c6"}, - {file = "frozenlist-1.4.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f4f399d28478d1f604c2ff9119907af9726aed73680e5ed1ca634d377abb087"}, - {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c5374b80521d3d3f2ec5572e05adc94601985cc526fb276d0c8574a6d749f1b3"}, - {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ce31ae3e19f3c902de379cf1323d90c649425b86de7bbdf82871b8a2a0615f3d"}, - {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7211ef110a9194b6042449431e08c4d80c0481e5891e58d429df5899690511c2"}, - {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:556de4430ce324c836789fa4560ca62d1591d2538b8ceb0b4f68fb7b2384a27a"}, - {file = "frozenlist-1.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7645a8e814a3ee34a89c4a372011dcd817964ce8cb273c8ed6119d706e9613e3"}, - {file = "frozenlist-1.4.0-cp39-cp39-win32.whl", hash = "sha256:19488c57c12d4e8095a922f328df3f179c820c212940a498623ed39160bc3c2f"}, - {file = "frozenlist-1.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:6221d84d463fb110bdd7619b69cb43878a11d51cbb9394ae3105d082d5199167"}, - {file = "frozenlist-1.4.0.tar.gz", hash = "sha256:09163bdf0b2907454042edb19f887c6d33806adc71fbd54afc14908bfdc22251"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"}, + {file = 
"frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"}, + {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"}, + {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"}, + {file = 
"frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"}, + {file = 
"frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"}, + {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"}, + {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"}, + {file = 
"frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"}, + {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"}, + {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20b51fa3f588ff2fe658663db52a41a4f7aa6c04f6201449c6c7c476bd255c0d"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:410478a0c562d1a5bcc2f7ea448359fcb050ed48b3c6f6f4f18c313a9bdb1826"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6321c9efe29975232da3bd0af0ad216800a47e93d763ce64f291917a381b8eb"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48f6a4533887e189dae092f1cf981f2e3885175f7a0f33c91fb5b7b682b6bab6"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6eb73fa5426ea69ee0e012fb59cdc76a15b1283d6e32e4f8dc4482ec67d1194d"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbeb989b5cc29e8daf7f976b421c220f1b8c731cbf22b9130d8815418ea45887"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:32453c1de775c889eb4e22f1197fe3bdfe457d16476ea407472b9442e6295f7a"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693945278a31f2086d9bf3df0fe8254bbeaef1fe71e1351c3bd730aa7d31c41b"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d0ce09d36d53bbbe566fe296965b23b961764c0bcf3ce2fa45f463745c04701"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a670dc61eb0d0eb7080890c13de3066790f9049b47b0de04007090807c776b0"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:dca69045298ce5c11fd539682cff879cc1e664c245d1c64da929813e54241d11"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a06339f38e9ed3a64e4c4e43aec7f59084033647f908e4259d279a52d3757d09"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b7f2f9f912dca3934c1baec2e4585a674ef16fe00218d833856408c48d5beee7"}, + {file = "frozenlist-1.4.1-cp38-cp38-win32.whl", hash = "sha256:e7004be74cbb7d9f34553a5ce5fb08be14fb33bc86f332fb71cbe5216362a497"}, + {file = "frozenlist-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:5a7d70357e7cee13f470c7883a063aae5fe209a493c57d86eb7f5a6f910fae09"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", 
hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"}, + {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"}, + {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"}, + {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"}, + {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, ] [[package]] name = "fsspec" -version = "2023.10.0" +version = "2023.12.2" 
description = "File-system specification" +category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2023.10.0-py3-none-any.whl", hash = "sha256:346a8f024efeb749d2a5fca7ba8854474b1ff9af7c3faaf636a4548781136529"}, - {file = "fsspec-2023.10.0.tar.gz", hash = "sha256:330c66757591df346ad3091a53bd907e15348c2ba17d63fd54f5c39c4457d2a5"}, + {file = "fsspec-2023.12.2-py3-none-any.whl", hash = "sha256:d800d87f72189a745fa3d6b033b9dc4a34ad069f60ca60b943a63599f5501960"}, + {file = "fsspec-2023.12.2.tar.gz", hash = "sha256:8548d39e8810b59c38014934f6b31e57f40c1b20f911f4cc2b85389c7e9bf0cb"}, ] [package.dependencies] @@ -1106,6 +1182,7 @@ tqdm = ["tqdm"] name = "furo" version = "2023.3.27" description = "A clean customisable Sphinx documentation theme." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1123,6 +1200,7 @@ sphinx-basic-ng = "*" name = "gitdb" version = "4.0.11" description = "Git Object Database" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1137,6 +1215,7 @@ smmap = ">=3.0.1,<6" name = "gitpython" version = "3.1.40" description = "GitPython is a Python library used to interact with Git repositories" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1152,13 +1231,14 @@ test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre [[package]] name = "huggingface-hub" -version = "0.19.4" +version = "0.20.2" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +category = "main" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.19.4-py3-none-any.whl", hash = "sha256:dba013f779da16f14b606492828f3760600a1e1801432d09fe1c33e50b825bb5"}, - {file = "huggingface_hub-0.19.4.tar.gz", hash = "sha256:176a4fc355a851c17550e7619488f383189727eab209534d7cef2114dae77b22"}, + {file = "huggingface_hub-0.20.2-py3-none-any.whl", hash = 
"sha256:53752eda2239d30a470c307a61cf9adcf136bc77b0a734338c7d04941af560d8"}, + {file = "huggingface_hub-0.20.2.tar.gz", hash = "sha256:215c5fceff631030c7a3d19ba7b588921c908b3f21eef31d160ebc245b200ff6"}, ] [package.dependencies] @@ -1171,15 +1251,14 @@ tqdm = ">=4.42.1" typing-extensions = ">=3.7.4.3" [package.extras] -all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] cli = ["InquirerPy (==0.3.4)"] -dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] -docs = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "hf-doc-builder", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", 
"types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)", "watchdog"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"] quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"] tensorflow = ["graphviz", "pydot", "tensorflow"] -testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] torch = ["torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] @@ -1187,6 +1266,7 @@ typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "t name = "idna" version = "3.6" description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -1198,6 +1278,7 @@ files = [ name = "imagesize" version = "1.4.1" description = "Getting image size from png/jpeg/jpeg2000/gif file" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1207,13 
+1288,14 @@ files = [ [[package]] name = "importlib-metadata" -version = "7.0.0" +version = "7.0.1" description = "Read metadata from Python packages" +category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-7.0.0-py3-none-any.whl", hash = "sha256:d97503976bb81f40a193d41ee6570868479c69d5068651eb039c40d850c59d67"}, - {file = "importlib_metadata-7.0.0.tar.gz", hash = "sha256:7fc841f8b8332803464e5dc1c63a2e59121f46ca186c0e2e182e80bf8c1319f7"}, + {file = "importlib_metadata-7.0.1-py3-none-any.whl", hash = "sha256:4805911c3a4ec7c3966410053e9ec6a1fecd629117df5adee56dfc9432a1081e"}, + {file = "importlib_metadata-7.0.1.tar.gz", hash = "sha256:f238736bb06590ae52ac1fab06a3a9ef1d8dce2b7a35b5ab329371d6c8f5d2cc"}, ] [package.dependencies] @@ -1228,6 +1310,7 @@ testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs name = "importlib-resources" version = "6.1.1" description = "Read resources from Python packages" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1246,6 +1329,7 @@ testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1255,13 +1339,14 @@ files = [ [[package]] name = "ipykernel" -version = "6.27.1" +version = "6.28.0" description = "IPython Kernel for Jupyter" +category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "ipykernel-6.27.1-py3-none-any.whl", hash = "sha256:dab88b47f112f9f7df62236511023c9bdeef67abc73af7c652e4ce4441601686"}, - {file = "ipykernel-6.27.1.tar.gz", hash = "sha256:7d5d594b6690654b4d299edba5e872dc17bb7396a8d0609c97cb7b8a1c605de6"}, + {file = "ipykernel-6.28.0-py3-none-any.whl", hash = "sha256:c6e9a9c63a7f4095c0a22a79f765f079f9ec7be4f2430a898ddea889e8665661"}, + {file = "ipykernel-6.28.0.tar.gz", hash = 
"sha256:69c11403d26de69df02225916f916b37ea4b9af417da0a8c827f84328d88e5f3"}, ] [package.dependencies] @@ -1270,12 +1355,12 @@ comm = ">=0.1.1" debugpy = ">=1.6.5" ipython = ">=7.23.1" jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" matplotlib-inline = ">=0.1" nest-asyncio = "*" packaging = "*" psutil = "*" -pyzmq = ">=20" +pyzmq = ">=24" tornado = ">=6.1" traitlets = ">=5.4.0" @@ -1290,6 +1375,7 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio" name = "ipython" version = "8.12.3" description = "IPython: Productive Interactive Computing" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1329,6 +1415,7 @@ test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pa name = "ipywidgets" version = "8.1.1" description = "Jupyter interactive widgets" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1350,6 +1437,7 @@ test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"] name = "isoduration" version = "20.11.0" description = "Operations with ISO 8601 durations" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1364,6 +1452,7 @@ arrow = ">=0.15.0" name = "isort" version = "5.8.0" description = "A Python utility / library to sort Python imports." +category = "dev" optional = false python-versions = ">=3.6,<4.0" files = [ @@ -1380,6 +1469,7 @@ requirements-deprecated-finder = ["pip-api", "pipreqs"] name = "jaxtyping" version = "0.2.19" description = "Type annotations and runtime checking for shape and dtype of JAX arrays, and PyTrees." +category = "main" optional = false python-versions = "~=3.8" files = [ @@ -1396,6 +1486,7 @@ typing-extensions = ">=3.7.4.1" name = "jedi" version = "0.19.1" description = "An autocompletion tool for Python that can be used for text editors." 
+category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1415,6 +1506,7 @@ testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1432,6 +1524,7 @@ i18n = ["Babel (>=2.7)"] name = "json5" version = "0.9.14" description = "A Python implementation of the JSON5 data format." +category = "dev" optional = false python-versions = "*" files = [ @@ -1446,6 +1539,7 @@ dev = ["hypothesis"] name = "jsonpointer" version = "2.4" description = "Identify specific nodes in a JSON document (RFC 6901)" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ @@ -1457,6 +1551,7 @@ files = [ name = "jsonschema" version = "4.20.0" description = "An implementation of JSON Schema validation for Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1486,13 +1581,14 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- [[package]] name = "jsonschema-specifications" -version = "2023.11.2" +version = "2023.12.1" description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" +category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "jsonschema_specifications-2023.11.2-py3-none-any.whl", hash = "sha256:e74ba7c0a65e8cb49dc26837d6cfe576557084a8b423ed16a420984228104f93"}, - {file = "jsonschema_specifications-2023.11.2.tar.gz", hash = "sha256:9472fc4fea474cd74bea4a2b190daeccb5a9e4db2ea80efcf7a1b582fc9a81b8"}, + {file = "jsonschema_specifications-2023.12.1-py3-none-any.whl", hash = "sha256:87e4fdf3a94858b8a2ba2778d9ba57d8a9cafca7c7489c46ba0d30a8bc6a9c3c"}, + {file = "jsonschema_specifications-2023.12.1.tar.gz", hash = "sha256:48a76787b3e70f5ed53f1160d2b81f586e4ca6d1548c5de7085d1682674764cc"}, ] [package.dependencies] @@ 
-1503,6 +1599,7 @@ referencing = ">=0.31.0" name = "jupyter" version = "1.0.0" description = "Jupyter metapackage. Install all the Jupyter components in one go." +category = "dev" optional = false python-versions = "*" files = [ @@ -1523,6 +1620,7 @@ qtconsole = "*" name = "jupyter-client" version = "8.6.0" description = "Jupyter protocol implementation and client libraries" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1532,7 +1630,7 @@ files = [ [package.dependencies] importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" python-dateutil = ">=2.8.2" pyzmq = ">=23.0" tornado = ">=6.2" @@ -1546,6 +1644,7 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt name = "jupyter-console" version = "6.6.3" description = "Jupyter terminal console" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1557,7 +1656,7 @@ files = [ ipykernel = ">=6.14" ipython = "*" jupyter-client = ">=7.0.0" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" prompt-toolkit = ">=3.0.30" pygments = "*" pyzmq = ">=17" @@ -1568,13 +1667,14 @@ test = ["flaky", "pexpect", "pytest"] [[package]] name = "jupyter-core" -version = "5.5.0" +version = "5.7.1" description = "Jupyter core package. A base package on which Jupyter projects rely." 
+category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_core-5.5.0-py3-none-any.whl", hash = "sha256:e11e02cd8ae0a9de5c6c44abf5727df9f2581055afe00b22183f621ba3585805"}, - {file = "jupyter_core-5.5.0.tar.gz", hash = "sha256:880b86053bf298a8724994f95e99b99130659022a4f7f45f563084b6223861d3"}, + {file = "jupyter_core-5.7.1-py3-none-any.whl", hash = "sha256:c65c82126453a723a2804aa52409930434598fd9d35091d63dfb919d2b765bb7"}, + {file = "jupyter_core-5.7.1.tar.gz", hash = "sha256:de61a9d7fc71240f688b2fb5ab659fbb56979458dc66a71decd098e03c79e218"}, ] [package.dependencies] @@ -1590,6 +1690,7 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "jupyter-events" version = "0.9.0" description = "Jupyter Event System library" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1615,6 +1716,7 @@ test = ["click", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>=0.19.0)", "p name = "jupyter-lsp" version = "2.2.1" description = "Multi-Language Server WebSocket proxy for Jupyter Notebook/Lab server" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1628,13 +1730,14 @@ jupyter-server = ">=1.1.2" [[package]] name = "jupyter-server" -version = "2.12.1" +version = "2.12.3" description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." 
+category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_server-2.12.1-py3-none-any.whl", hash = "sha256:fd030dd7be1ca572e4598203f718df6630c12bd28a599d7f1791c4d7938e1010"}, - {file = "jupyter_server-2.12.1.tar.gz", hash = "sha256:dc77b7dcc5fc0547acba2b2844f01798008667201eea27c6319ff9257d700a6d"}, + {file = "jupyter_server-2.12.3-py3-none-any.whl", hash = "sha256:6f85310ea5e6068568a521f079fba99d8d17e4884dd1d602ab0f43b3115204a8"}, + {file = "jupyter_server-2.12.3.tar.gz", hash = "sha256:a1d2d51e497b1a6256c48b6940b0dd49b2553981baf1690077c37792f1fa23a1"}, ] [package.dependencies] @@ -1642,7 +1745,7 @@ anyio = ">=3.1.0" argon2-cffi = "*" jinja2 = "*" jupyter-client = ">=7.4.4" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" jupyter-events = ">=0.9.0" jupyter-server-terminals = "*" nbconvert = ">=6.4.4" @@ -1664,13 +1767,14 @@ test = ["flaky", "ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-sc [[package]] name = "jupyter-server-terminals" -version = "0.4.4" +version = "0.5.1" description = "A Jupyter Server Extension Providing Terminals." 
+category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_server_terminals-0.4.4-py3-none-any.whl", hash = "sha256:75779164661cec02a8758a5311e18bb8eb70c4e86c6b699403100f1585a12a36"}, - {file = "jupyter_server_terminals-0.4.4.tar.gz", hash = "sha256:57ab779797c25a7ba68e97bcfb5d7740f2b5e8a83b5e8102b10438041a7eac5d"}, + {file = "jupyter_server_terminals-0.5.1-py3-none-any.whl", hash = "sha256:5e63e947ddd97bb2832db5ef837a258d9ccd4192cd608c1270850ad947ae5dd7"}, + {file = "jupyter_server_terminals-0.5.1.tar.gz", hash = "sha256:16d3be9cf48be6a1f943f3a6c93c033be259cf4779184c66421709cf63dccfea"}, ] [package.dependencies] @@ -1678,18 +1782,19 @@ pywinpty = {version = ">=2.0.3", markers = "os_name == \"nt\""} terminado = ">=0.8.3" [package.extras] -docs = ["jinja2", "jupyter-server", "mistune (<3.0)", "myst-parser", "nbformat", "packaging", "pydata-sphinx-theme", "sphinxcontrib-github-alt", "sphinxcontrib-openapi", "sphinxcontrib-spelling", "sphinxemoji", "tornado"] -test = ["coverage", "jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-cov", "pytest-jupyter[server] (>=0.5.3)", "pytest-timeout"] +docs = ["jinja2", "jupyter-server", "mistune (<4.0)", "myst-parser", "nbformat", "packaging", "pydata-sphinx-theme", "sphinxcontrib-github-alt", "sphinxcontrib-openapi", "sphinxcontrib-spelling", "sphinxemoji", "tornado"] +test = ["jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-jupyter[server] (>=0.5.3)", "pytest-timeout"] [[package]] name = "jupyterlab" -version = "4.0.9" +version = "4.0.10" description = "JupyterLab computational environment" +category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "jupyterlab-4.0.9-py3-none-any.whl", hash = "sha256:9f6f8e36d543fdbcc3df961a1d6a3f524b4a4001be0327a398f68fa4e534107c"}, - {file = "jupyterlab-4.0.9.tar.gz", hash = "sha256:9ebada41d52651f623c0c9f069ddb8a21d6848e4c887d8e5ddc0613166ed5c0b"}, + {file = "jupyterlab-4.0.10-py3-none-any.whl", hash = 
"sha256:fe010ad9e37017488b468632ef2ead255fc7c671c5b64d9ca13e1f7b7e665c37"}, + {file = "jupyterlab-4.0.10.tar.gz", hash = "sha256:46177eb8ede70dc73be922ac99f8ef943bdc2dfbc6a31b353c4bde848a35dee1"}, ] [package.dependencies] @@ -1709,7 +1814,7 @@ tornado = ">=6.2.0" traitlets = "*" [package.extras] -dev = ["black[jupyter] (==23.10.1)", "build", "bump2version", "coverage", "hatch", "pre-commit", "pytest-cov", "ruff (==0.1.4)"] +dev = ["build", "bump2version", "coverage", "hatch", "pre-commit", "pytest-cov", "ruff (==0.1.6)"] docs = ["jsx-lexer", "myst-parser", "pydata-sphinx-theme (>=0.13.0)", "pytest", "pytest-check-links", "pytest-tornasync", "sphinx (>=1.8,<7.2.0)", "sphinx-copybutton"] docs-screenshots = ["altair (==5.0.1)", "ipython (==8.14.0)", "ipywidgets (==8.0.6)", "jupyterlab-geojson (==3.4.0)", "jupyterlab-language-pack-zh-cn (==4.0.post0)", "matplotlib (==3.7.1)", "nbconvert (>=7.0.0)", "pandas (==2.0.2)", "scipy (==1.10.1)", "vega-datasets (==0.9.0)"] test = ["coverage", "pytest (>=7.0)", "pytest-check-links (>=0.7)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter (>=0.5.3)", "pytest-timeout", "pytest-tornasync", "requests", "requests-cache", "virtualenv"] @@ -1718,6 +1823,7 @@ test = ["coverage", "pytest (>=7.0)", "pytest-check-links (>=0.7)", "pytest-cons name = "jupyterlab-pygments" version = "0.3.0" description = "Pygments theme using JupyterLab CSS variables" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1729,6 +1835,7 @@ files = [ name = "jupyterlab-server" version = "2.25.2" description = "A set of server components for JupyterLab and JupyterLab like applications." 
+category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1755,6 +1862,7 @@ test = ["hatch", "ipykernel", "openapi-core (>=0.18.0,<0.19.0)", "openapi-spec-v name = "jupyterlab-widgets" version = "3.0.9" description = "Jupyter interactive widgets for JupyterLab" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1766,6 +1874,7 @@ files = [ name = "libcst" version = "1.1.0" description = "A concrete syntax tree with AST-like properties for Python 3.5, 3.6, 3.7, 3.8, 3.9, and 3.10 programs." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1814,6 +1923,7 @@ dev = ["Sphinx (>=5.1.1)", "black (==23.9.1)", "build (>=0.10.0)", "coverage (>= name = "livereload" version = "2.6.3" description = "Python LiveReload is an awesome tool for web developers" +category = "dev" optional = false python-versions = "*" files = [ @@ -1829,6 +1939,7 @@ tornado = {version = "*", markers = "python_version > \"2.7\""} name = "markdown-it-py" version = "2.2.0" description = "Python port of markdown-it. Markdown parsing, done right!" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1853,6 +1964,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "markupsafe" version = "2.1.3" description = "Safely add untrusted strings to HTML/XML markup." 
+category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1922,6 +2034,7 @@ files = [ name = "matplotlib-inline" version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -1936,6 +2049,7 @@ traitlets = "*" name = "mdit-py-plugins" version = "0.3.5" description = "Collection of plugins for markdown-it-py" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1955,6 +2069,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "mdurl" version = "0.1.2" description = "Markdown URL utilities" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1966,6 +2081,7 @@ files = [ name = "mistune" version = "3.0.2" description = "A sane and fast Markdown parser with useful plugins and renderers" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1977,6 +2093,7 @@ files = [ name = "mpmath" version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" +category = "main" optional = false python-versions = "*" files = [ @@ -1994,6 +2111,7 @@ tests = ["pytest (>=4.6)"] name = "multidict" version = "6.0.4" description = "multidict implementation" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2077,6 +2195,7 @@ files = [ name = "multiprocess" version = "0.70.15" description = "better multiprocessing and multithreading in Python" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2103,38 +2222,39 @@ dill = ">=0.3.7" [[package]] name = "mypy" -version = "1.7.1" +version = "1.8.0" description = "Optional static typing for Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "mypy-1.7.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:12cce78e329838d70a204293e7b29af9faa3ab14899aec397798a4b41be7f340"}, - {file = "mypy-1.7.1-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:1484b8fa2c10adf4474f016e09d7a159602f3239075c7bf9f1627f5acf40ad49"}, - {file = "mypy-1.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31902408f4bf54108bbfb2e35369877c01c95adc6192958684473658c322c8a5"}, - {file = "mypy-1.7.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f2c2521a8e4d6d769e3234350ba7b65ff5d527137cdcde13ff4d99114b0c8e7d"}, - {file = "mypy-1.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:fcd2572dd4519e8a6642b733cd3a8cfc1ef94bafd0c1ceed9c94fe736cb65b6a"}, - {file = "mypy-1.7.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4b901927f16224d0d143b925ce9a4e6b3a758010673eeded9b748f250cf4e8f7"}, - {file = "mypy-1.7.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2f7f6985d05a4e3ce8255396df363046c28bea790e40617654e91ed580ca7c51"}, - {file = "mypy-1.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:944bdc21ebd620eafefc090cdf83158393ec2b1391578359776c00de00e8907a"}, - {file = "mypy-1.7.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9c7ac372232c928fff0645d85f273a726970c014749b924ce5710d7d89763a28"}, - {file = "mypy-1.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:f6efc9bd72258f89a3816e3a98c09d36f079c223aa345c659622f056b760ab42"}, - {file = "mypy-1.7.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6dbdec441c60699288adf051f51a5d512b0d818526d1dcfff5a41f8cd8b4aaf1"}, - {file = "mypy-1.7.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4fc3d14ee80cd22367caaaf6e014494415bf440980a3045bf5045b525680ac33"}, - {file = "mypy-1.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c6e4464ed5f01dc44dc9821caf67b60a4e5c3b04278286a85c067010653a0eb"}, - {file = "mypy-1.7.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:d9b338c19fa2412f76e17525c1b4f2c687a55b156320acb588df79f2e6fa9fea"}, - {file = "mypy-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:204e0d6de5fd2317394a4eff62065614c4892d5a4d1a7ee55b765d7a3d9e3f82"}, - {file = 
"mypy-1.7.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:84860e06ba363d9c0eeabd45ac0fde4b903ad7aa4f93cd8b648385a888e23200"}, - {file = "mypy-1.7.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8c5091ebd294f7628eb25ea554852a52058ac81472c921150e3a61cdd68f75a7"}, - {file = "mypy-1.7.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40716d1f821b89838589e5b3106ebbc23636ffdef5abc31f7cd0266db936067e"}, - {file = "mypy-1.7.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5cf3f0c5ac72139797953bd50bc6c95ac13075e62dbfcc923571180bebb662e9"}, - {file = "mypy-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:78e25b2fd6cbb55ddfb8058417df193f0129cad5f4ee75d1502248e588d9e0d7"}, - {file = "mypy-1.7.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:75c4d2a6effd015786c87774e04331b6da863fc3fc4e8adfc3b40aa55ab516fe"}, - {file = "mypy-1.7.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2643d145af5292ee956aa0a83c2ce1038a3bdb26e033dadeb2f7066fb0c9abce"}, - {file = "mypy-1.7.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75aa828610b67462ffe3057d4d8a4112105ed211596b750b53cbfe182f44777a"}, - {file = "mypy-1.7.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ee5d62d28b854eb61889cde4e1dbc10fbaa5560cb39780c3995f6737f7e82120"}, - {file = "mypy-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:72cf32ce7dd3562373f78bd751f73c96cfb441de147cc2448a92c1a308bd0ca6"}, - {file = "mypy-1.7.1-py3-none-any.whl", hash = "sha256:f7c5d642db47376a0cc130f0de6d055056e010debdaf0707cd2b0fc7e7ef30ea"}, - {file = "mypy-1.7.1.tar.gz", hash = "sha256:fcb6d9afb1b6208b4c712af0dafdc650f518836065df0d4fb1d800f5d6773db2"}, + {file = "mypy-1.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485a8942f671120f76afffff70f259e1cd0f0cfe08f81c05d8816d958d4577d3"}, + {file = "mypy-1.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:df9824ac11deaf007443e7ed2a4a26bebff98d2bc43c6da21b2b64185da011c4"}, + {file = 
"mypy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2afecd6354bbfb6e0160f4e4ad9ba6e4e003b767dd80d85516e71f2e955ab50d"}, + {file = "mypy-1.8.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8963b83d53ee733a6e4196954502b33567ad07dfd74851f32be18eb932fb1cb9"}, + {file = "mypy-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:e46f44b54ebddbeedbd3d5b289a893219065ef805d95094d16a0af6630f5d410"}, + {file = "mypy-1.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:855fe27b80375e5c5878492f0729540db47b186509c98dae341254c8f45f42ae"}, + {file = "mypy-1.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4c886c6cce2d070bd7df4ec4a05a13ee20c0aa60cb587e8d1265b6c03cf91da3"}, + {file = "mypy-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d19c413b3c07cbecf1f991e2221746b0d2a9410b59cb3f4fb9557f0365a1a817"}, + {file = "mypy-1.8.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9261ed810972061388918c83c3f5cd46079d875026ba97380f3e3978a72f503d"}, + {file = "mypy-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:51720c776d148bad2372ca21ca29256ed483aa9a4cdefefcef49006dff2a6835"}, + {file = "mypy-1.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:52825b01f5c4c1c4eb0db253ec09c7aa17e1a7304d247c48b6f3599ef40db8bd"}, + {file = "mypy-1.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f5ac9a4eeb1ec0f1ccdc6f326bcdb464de5f80eb07fb38b5ddd7b0de6bc61e55"}, + {file = "mypy-1.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afe3fe972c645b4632c563d3f3eff1cdca2fa058f730df2b93a35e3b0c538218"}, + {file = "mypy-1.8.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:42c6680d256ab35637ef88891c6bd02514ccb7e1122133ac96055ff458f93fc3"}, + {file = "mypy-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:720a5ca70e136b675af3af63db533c1c8c9181314d207568bbe79051f122669e"}, + {file = "mypy-1.8.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = 
"sha256:028cf9f2cae89e202d7b6593cd98db6759379f17a319b5faf4f9978d7084cdc6"}, + {file = "mypy-1.8.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4e6d97288757e1ddba10dd9549ac27982e3e74a49d8d0179fc14d4365c7add66"}, + {file = "mypy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f1478736fcebb90f97e40aff11a5f253af890c845ee0c850fe80aa060a267c6"}, + {file = "mypy-1.8.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42419861b43e6962a649068a61f4a4839205a3ef525b858377a960b9e2de6e0d"}, + {file = "mypy-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:2b5b6c721bd4aabaadead3a5e6fa85c11c6c795e0c81a7215776ef8afc66de02"}, + {file = "mypy-1.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5c1538c38584029352878a0466f03a8ee7547d7bd9f641f57a0f3017a7c905b8"}, + {file = "mypy-1.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ef4be7baf08a203170f29e89d79064463b7fc7a0908b9d0d5114e8009c3a259"}, + {file = "mypy-1.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7178def594014aa6c35a8ff411cf37d682f428b3b5617ca79029d8ae72f5402b"}, + {file = "mypy-1.8.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ab3c84fa13c04aeeeabb2a7f67a25ef5d77ac9d6486ff33ded762ef353aa5592"}, + {file = "mypy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:99b00bc72855812a60d253420d8a2eae839b0afa4938f09f4d2aa9bb4654263a"}, + {file = "mypy-1.8.0-py3-none-any.whl", hash = "sha256:538fd81bb5e430cc1381a443971c0475582ff9f434c16cd46d2c66763ce85d9d"}, + {file = "mypy-1.8.0.tar.gz", hash = "sha256:6ff8b244d7085a0b425b56d327b480c3b29cafbd2eff27316a004f9a7391ae07"}, ] [package.dependencies] @@ -2152,6 +2272,7 @@ reports = ["lxml"] name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." 
+category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2163,6 +2284,7 @@ files = [ name = "myst-parser" version = "1.0.0" description = "An extended [CommonMark](https://spec.commonmark.org/) compliant parser," +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2189,6 +2311,7 @@ testing-docutils = ["pygments", "pytest (>=7,<8)", "pytest-param-files (>=0.3.4, name = "nbclient" version = "0.9.0" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." +category = "dev" optional = false python-versions = ">=3.8.0" files = [ @@ -2198,7 +2321,7 @@ files = [ [package.dependencies] jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" +jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" nbformat = ">=5.1" traitlets = ">=5.4" @@ -2209,13 +2332,14 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= [[package]] name = "nbconvert" -version = "7.12.0" +version = "7.14.0" description = "Converting Jupyter Notebooks" +category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "nbconvert-7.12.0-py3-none-any.whl", hash = "sha256:5b6c848194d270cc55fb691169202620d7b52a12fec259508d142ecbe4219310"}, - {file = "nbconvert-7.12.0.tar.gz", hash = "sha256:b1564bd89f69a74cd6398b0362da94db07aafb991b7857216a766204a71612c0"}, + {file = "nbconvert-7.14.0-py3-none-any.whl", hash = "sha256:483dde47facdaa4875903d651305ad53cd76e2255ae3c61efe412a95f2d22a24"}, + {file = "nbconvert-7.14.0.tar.gz", hash = "sha256:92b9a44b63e5a7fb4f6fa0ef41261e35c16925046ccd1c04a5c8099bf100476e"}, ] [package.dependencies] @@ -2242,13 +2366,14 @@ docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sp qtpdf = ["nbconvert[qtpng]"] qtpng = ["pyqtwebengine (>=5.15)"] serve = ["tornado (>=6.1)"] -test = ["flaky", "ipykernel", "ipywidgets (>=7)", "pytest"] +test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest"] webpdf = ["playwright"] 
[[package]] name = "nbformat" version = "5.9.2" description = "The Jupyter Notebook format" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2270,6 +2395,7 @@ test = ["pep440", "pre-commit", "pytest", "testpath"] name = "nbsphinx" version = "0.9.3" description = "Jupyter Notebook Tools for Sphinx" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2289,6 +2415,7 @@ traitlets = ">=5" name = "nbval" version = "0.10.0" description = "A py.test plugin to validate Jupyter notebooks" +category = "dev" optional = false python-versions = ">=3.6, <4" files = [ @@ -2307,6 +2434,7 @@ pytest = ">=2.8" name = "nest-asyncio" version = "1.5.8" description = "Patch asyncio to allow nested event loops" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2318,6 +2446,7 @@ files = [ name = "networkx" version = "3.1" description = "Python package for creating and manipulating graphs and networks" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2336,6 +2465,7 @@ test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] name = "notebook" version = "7.0.6" description = "Jupyter Notebook - A web-based notebook environment for interactive computing" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2359,6 +2489,7 @@ test = ["importlib-resources (>=5.0)", "ipykernel", "jupyter-server[test] (>=2.4 name = "notebook-shim" version = "0.2.3" description = "A shim layer for notebook traits and config" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2376,6 +2507,7 @@ test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync" name = "numpy" version = "1.24.4" description = "Fundamental package for array computing in Python" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2411,53 +2543,55 @@ files = [ [[package]] name = "numpy" -version = "1.26.2" +version = "1.26.3" description = 
"Fundamental package for array computing in Python" +category = "main" optional = false python-versions = ">=3.9" files = [ - {file = "numpy-1.26.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3703fc9258a4a122d17043e57b35e5ef1c5a5837c3db8be396c82e04c1cf9b0f"}, - {file = "numpy-1.26.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cc392fdcbd21d4be6ae1bb4475a03ce3b025cd49a9be5345d76d7585aea69440"}, - {file = "numpy-1.26.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36340109af8da8805d8851ef1d74761b3b88e81a9bd80b290bbfed61bd2b4f75"}, - {file = "numpy-1.26.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcc008217145b3d77abd3e4d5ef586e3bdfba8fe17940769f8aa09b99e856c00"}, - {file = "numpy-1.26.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3ced40d4e9e18242f70dd02d739e44698df3dcb010d31f495ff00a31ef6014fe"}, - {file = "numpy-1.26.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b272d4cecc32c9e19911891446b72e986157e6a1809b7b56518b4f3755267523"}, - {file = "numpy-1.26.2-cp310-cp310-win32.whl", hash = "sha256:22f8fc02fdbc829e7a8c578dd8d2e15a9074b630d4da29cda483337e300e3ee9"}, - {file = "numpy-1.26.2-cp310-cp310-win_amd64.whl", hash = "sha256:26c9d33f8e8b846d5a65dd068c14e04018d05533b348d9eaeef6c1bd787f9919"}, - {file = "numpy-1.26.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b96e7b9c624ef3ae2ae0e04fa9b460f6b9f17ad8b4bec6d7756510f1f6c0c841"}, - {file = "numpy-1.26.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:aa18428111fb9a591d7a9cc1b48150097ba6a7e8299fb56bdf574df650e7d1f1"}, - {file = "numpy-1.26.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06fa1ed84aa60ea6ef9f91ba57b5ed963c3729534e6e54055fc151fad0423f0a"}, - {file = "numpy-1.26.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96ca5482c3dbdd051bcd1fce8034603d6ebfc125a7bd59f55b40d8f5d246832b"}, - {file = "numpy-1.26.2-cp311-cp311-musllinux_1_1_aarch64.whl", 
hash = "sha256:854ab91a2906ef29dc3925a064fcd365c7b4da743f84b123002f6139bcb3f8a7"}, - {file = "numpy-1.26.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f43740ab089277d403aa07567be138fc2a89d4d9892d113b76153e0e412409f8"}, - {file = "numpy-1.26.2-cp311-cp311-win32.whl", hash = "sha256:a2bbc29fcb1771cd7b7425f98b05307776a6baf43035d3b80c4b0f29e9545186"}, - {file = "numpy-1.26.2-cp311-cp311-win_amd64.whl", hash = "sha256:2b3fca8a5b00184828d12b073af4d0fc5fdd94b1632c2477526f6bd7842d700d"}, - {file = "numpy-1.26.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a4cd6ed4a339c21f1d1b0fdf13426cb3b284555c27ac2f156dfdaaa7e16bfab0"}, - {file = "numpy-1.26.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5d5244aabd6ed7f312268b9247be47343a654ebea52a60f002dc70c769048e75"}, - {file = "numpy-1.26.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a3cdb4d9c70e6b8c0814239ead47da00934666f668426fc6e94cce869e13fd7"}, - {file = "numpy-1.26.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa317b2325f7aa0a9471663e6093c210cb2ae9c0ad824732b307d2c51983d5b6"}, - {file = "numpy-1.26.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:174a8880739c16c925799c018f3f55b8130c1f7c8e75ab0a6fa9d41cab092fd6"}, - {file = "numpy-1.26.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f79b231bf5c16b1f39c7f4875e1ded36abee1591e98742b05d8a0fb55d8a3eec"}, - {file = "numpy-1.26.2-cp312-cp312-win32.whl", hash = "sha256:4a06263321dfd3598cacb252f51e521a8cb4b6df471bb12a7ee5cbab20ea9167"}, - {file = "numpy-1.26.2-cp312-cp312-win_amd64.whl", hash = "sha256:b04f5dc6b3efdaab541f7857351aac359e6ae3c126e2edb376929bd3b7f92d7e"}, - {file = "numpy-1.26.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4eb8df4bf8d3d90d091e0146f6c28492b0be84da3e409ebef54349f71ed271ef"}, - {file = "numpy-1.26.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1a13860fdcd95de7cf58bd6f8bc5a5ef81c0b0625eb2c9a783948847abbef2c2"}, - {file = 
"numpy-1.26.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64308ebc366a8ed63fd0bf426b6a9468060962f1a4339ab1074c228fa6ade8e3"}, - {file = "numpy-1.26.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baf8aab04a2c0e859da118f0b38617e5ee65d75b83795055fb66c0d5e9e9b818"}, - {file = "numpy-1.26.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d73a3abcac238250091b11caef9ad12413dab01669511779bc9b29261dd50210"}, - {file = "numpy-1.26.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:b361d369fc7e5e1714cf827b731ca32bff8d411212fccd29ad98ad622449cc36"}, - {file = "numpy-1.26.2-cp39-cp39-win32.whl", hash = "sha256:bd3f0091e845164a20bd5a326860c840fe2af79fa12e0469a12768a3ec578d80"}, - {file = "numpy-1.26.2-cp39-cp39-win_amd64.whl", hash = "sha256:2beef57fb031dcc0dc8fa4fe297a742027b954949cabb52a2a376c144e5e6060"}, - {file = "numpy-1.26.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:1cc3d5029a30fb5f06704ad6b23b35e11309491c999838c31f124fee32107c79"}, - {file = "numpy-1.26.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94cc3c222bb9fb5a12e334d0479b97bb2df446fbe622b470928f5284ffca3f8d"}, - {file = "numpy-1.26.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fe6b44fb8fcdf7eda4ef4461b97b3f63c466b27ab151bec2366db8b197387841"}, - {file = "numpy-1.26.2.tar.gz", hash = "sha256:f65738447676ab5777f11e6bbbdb8ce11b785e105f690bc45966574816b6d3ea"}, + {file = "numpy-1.26.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:806dd64230dbbfaca8a27faa64e2f414bf1c6622ab78cc4264f7f5f028fee3bf"}, + {file = "numpy-1.26.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02f98011ba4ab17f46f80f7f8f1c291ee7d855fcef0a5a98db80767a468c85cd"}, + {file = "numpy-1.26.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d45b3ec2faed4baca41c76617fcdcfa4f684ff7a151ce6fc78ad3b6e85af0a6"}, + {file = "numpy-1.26.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash 
= "sha256:bdd2b45bf079d9ad90377048e2747a0c82351989a2165821f0c96831b4a2a54b"}, + {file = "numpy-1.26.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:211ddd1e94817ed2d175b60b6374120244a4dd2287f4ece45d49228b4d529178"}, + {file = "numpy-1.26.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b1240f767f69d7c4c8a29adde2310b871153df9b26b5cb2b54a561ac85146485"}, + {file = "numpy-1.26.3-cp310-cp310-win32.whl", hash = "sha256:21a9484e75ad018974a2fdaa216524d64ed4212e418e0a551a2d83403b0531d3"}, + {file = "numpy-1.26.3-cp310-cp310-win_amd64.whl", hash = "sha256:9e1591f6ae98bcfac2a4bbf9221c0b92ab49762228f38287f6eeb5f3f55905ce"}, + {file = "numpy-1.26.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b831295e5472954104ecb46cd98c08b98b49c69fdb7040483aff799a755a7374"}, + {file = "numpy-1.26.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9e87562b91f68dd8b1c39149d0323b42e0082db7ddb8e934ab4c292094d575d6"}, + {file = "numpy-1.26.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c66d6fec467e8c0f975818c1796d25c53521124b7cfb760114be0abad53a0a2"}, + {file = "numpy-1.26.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f25e2811a9c932e43943a2615e65fc487a0b6b49218899e62e426e7f0a57eeda"}, + {file = "numpy-1.26.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:af36e0aa45e25c9f57bf684b1175e59ea05d9a7d3e8e87b7ae1a1da246f2767e"}, + {file = "numpy-1.26.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:51c7f1b344f302067b02e0f5b5d2daa9ed4a721cf49f070280ac202738ea7f00"}, + {file = "numpy-1.26.3-cp311-cp311-win32.whl", hash = "sha256:7ca4f24341df071877849eb2034948459ce3a07915c2734f1abb4018d9c49d7b"}, + {file = "numpy-1.26.3-cp311-cp311-win_amd64.whl", hash = "sha256:39763aee6dfdd4878032361b30b2b12593fb445ddb66bbac802e2113eb8a6ac4"}, + {file = "numpy-1.26.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a7081fd19a6d573e1a05e600c82a1c421011db7935ed0d5c483e9dd96b99cf13"}, + {file = 
"numpy-1.26.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12c70ac274b32bc00c7f61b515126c9205323703abb99cd41836e8125ea0043e"}, + {file = "numpy-1.26.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f784e13e598e9594750b2ef6729bcd5a47f6cfe4a12cca13def35e06d8163e3"}, + {file = "numpy-1.26.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f24750ef94d56ce6e33e4019a8a4d68cfdb1ef661a52cdaee628a56d2437419"}, + {file = "numpy-1.26.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:77810ef29e0fb1d289d225cabb9ee6cf4d11978a00bb99f7f8ec2132a84e0166"}, + {file = "numpy-1.26.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8ed07a90f5450d99dad60d3799f9c03c6566709bd53b497eb9ccad9a55867f36"}, + {file = "numpy-1.26.3-cp312-cp312-win32.whl", hash = "sha256:f73497e8c38295aaa4741bdfa4fda1a5aedda5473074369eca10626835445511"}, + {file = "numpy-1.26.3-cp312-cp312-win_amd64.whl", hash = "sha256:da4b0c6c699a0ad73c810736303f7fbae483bcb012e38d7eb06a5e3b432c981b"}, + {file = "numpy-1.26.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1666f634cb3c80ccbd77ec97bc17337718f56d6658acf5d3b906ca03e90ce87f"}, + {file = "numpy-1.26.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18c3319a7d39b2c6a9e3bb75aab2304ab79a811ac0168a671a62e6346c29b03f"}, + {file = "numpy-1.26.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b7e807d6888da0db6e7e75838444d62495e2b588b99e90dd80c3459594e857b"}, + {file = "numpy-1.26.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4d362e17bcb0011738c2d83e0a65ea8ce627057b2fdda37678f4374a382a137"}, + {file = "numpy-1.26.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b8c275f0ae90069496068c714387b4a0eba5d531aace269559ff2b43655edd58"}, + {file = "numpy-1.26.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cc0743f0302b94f397a4a65a660d4cd24267439eb16493fb3caad2e4389bccbb"}, + {file = "numpy-1.26.3-cp39-cp39-win32.whl", hash = 
"sha256:9bc6d1a7f8cedd519c4b7b1156d98e051b726bf160715b769106661d567b3f03"}, + {file = "numpy-1.26.3-cp39-cp39-win_amd64.whl", hash = "sha256:867e3644e208c8922a3be26fc6bbf112a035f50f0a86497f98f228c50c607bb2"}, + {file = "numpy-1.26.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3c67423b3703f8fbd90f5adaa37f85b5794d3366948efe9a5190a5f3a83fc34e"}, + {file = "numpy-1.26.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46f47ee566d98849323f01b349d58f2557f02167ee301e5e28809a8c0e27a2d0"}, + {file = "numpy-1.26.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a8474703bffc65ca15853d5fd4d06b18138ae90c17c8d12169968e998e448bb5"}, + {file = "numpy-1.26.3.tar.gz", hash = "sha256:697df43e2b6310ecc9d95f05d5ef20eacc09c7c4ecc9da3f235d39e71b7da1e4"}, ] [[package]] name = "nvidia-cublas-cu12" version = "12.1.3.1" description = "CUBLAS native runtime libraries" +category = "main" optional = false python-versions = ">=3" files = [ @@ -2469,6 +2603,7 @@ files = [ name = "nvidia-cuda-cupti-cu12" version = "12.1.105" description = "CUDA profiling tools runtime libs." 
+category = "main" optional = false python-versions = ">=3" files = [ @@ -2480,6 +2615,7 @@ files = [ name = "nvidia-cuda-nvrtc-cu12" version = "12.1.105" description = "NVRTC native runtime libraries" +category = "main" optional = false python-versions = ">=3" files = [ @@ -2491,6 +2627,7 @@ files = [ name = "nvidia-cuda-runtime-cu12" version = "12.1.105" description = "CUDA Runtime native Libraries" +category = "main" optional = false python-versions = ">=3" files = [ @@ -2502,6 +2639,7 @@ files = [ name = "nvidia-cudnn-cu12" version = "8.9.2.26" description = "cuDNN runtime libraries" +category = "main" optional = false python-versions = ">=3" files = [ @@ -2515,6 +2653,7 @@ nvidia-cublas-cu12 = "*" name = "nvidia-cufft-cu12" version = "11.0.2.54" description = "CUFFT native runtime libraries" +category = "main" optional = false python-versions = ">=3" files = [ @@ -2526,6 +2665,7 @@ files = [ name = "nvidia-curand-cu12" version = "10.3.2.106" description = "CURAND native runtime libraries" +category = "main" optional = false python-versions = ">=3" files = [ @@ -2537,6 +2677,7 @@ files = [ name = "nvidia-cusolver-cu12" version = "11.4.5.107" description = "CUDA solver native runtime libraries" +category = "main" optional = false python-versions = ">=3" files = [ @@ -2553,6 +2694,7 @@ nvidia-nvjitlink-cu12 = "*" name = "nvidia-cusparse-cu12" version = "12.1.0.106" description = "CUSPARSE native runtime libraries" +category = "main" optional = false python-versions = ">=3" files = [ @@ -2567,6 +2709,7 @@ nvidia-nvjitlink-cu12 = "*" name = "nvidia-nccl-cu12" version = "2.18.1" description = "NVIDIA Collective Communication Library (NCCL) Runtime" +category = "main" optional = false python-versions = ">=3" files = [ @@ -2577,6 +2720,7 @@ files = [ name = "nvidia-nvjitlink-cu12" version = "12.3.101" description = "Nvidia JIT LTO Library" +category = "main" optional = false python-versions = ">=3" files = [ @@ -2588,6 +2732,7 @@ files = [ name = "nvidia-nvtx-cu12" 
version = "12.1.105" description = "NVIDIA Tools Extension" +category = "main" optional = false python-versions = ">=3" files = [ @@ -2599,6 +2744,7 @@ files = [ name = "overrides" version = "7.4.0" description = "A decorator to automatically detect mismatch when overriding a method." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2610,6 +2756,7 @@ files = [ name = "packaging" version = "23.2" description = "Core utilities for Python packages" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2621,6 +2768,7 @@ files = [ name = "pandas" version = "2.0.3" description = "Powerful data structures for data analysis, time series, and statistics" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2654,7 +2802,7 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.2" @@ -2688,6 +2836,7 @@ xml = ["lxml (>=4.6.3)"] name = "pandoc" version = "2.3" description = "Pandoc Documents for Python" +category = "dev" optional = false python-versions = "*" files = [ @@ -2702,6 +2851,7 @@ ply = "*" name = "pandocfilters" version = "1.5.0" description = "Utilities for writing pandoc filters in python" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2713,6 +2863,7 @@ files = [ name = "parso" version = "0.8.3" description = "A Python Parser" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2726,19 +2877,21 @@ testing = ["docopt", "pytest (<6.0.0)"] [[package]] name = "pathspec" -version = "0.12.0" +version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." 
+category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "pathspec-0.12.0-py3-none-any.whl", hash = "sha256:f1f8a7eab698c357945c85ed79715e014612b8584faebe209dca4558e2b09513"}, - {file = "pathspec-0.12.0.tar.gz", hash = "sha256:c57e16065a97b7beb175f13c84d27cb05f7b7315741c2fbd5de541042f4ea6e1"}, + {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, + {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, ] [[package]] name = "pexpect" version = "4.9.0" description = "Pexpect allows easy control of interactive console applications." +category = "dev" optional = false python-versions = "*" files = [ @@ -2753,6 +2906,7 @@ ptyprocess = ">=0.5" name = "pickleshare" version = "0.7.5" description = "Tiny 'shelve'-like database with concurrency support" +category = "dev" optional = false python-versions = "*" files = [ @@ -2764,6 +2918,7 @@ files = [ name = "pkgutil-resolve-name" version = "1.3.10" description = "Resolve a name to an object." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2775,6 +2930,7 @@ files = [ name = "platformdirs" version = "4.1.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 
+category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2790,6 +2946,7 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-co name = "plotly" version = "5.18.0" description = "An open-source, interactive data visualization library for Python" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2805,6 +2962,7 @@ tenacity = ">=6.2.0" name = "pluggy" version = "1.3.0" description = "plugin and hook calling mechanisms for python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2820,6 +2978,7 @@ testing = ["pytest", "pytest-benchmark"] name = "plumbum" version = "1.8.2" description = "Plumbum: shell combinators library" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2839,6 +2998,7 @@ ssh = ["paramiko"] name = "ply" version = "3.11" description = "Python Lex & Yacc" +category = "dev" optional = false python-versions = "*" files = [ @@ -2850,6 +3010,7 @@ files = [ name = "pockets" version = "0.9.1" description = "A collection of helpful Python tools!" +category = "dev" optional = false python-versions = "*" files = [ @@ -2864,6 +3025,7 @@ six = ">=1.5.2" name = "prometheus-client" version = "0.19.0" description = "Python client for the Prometheus monitoring system." 
+category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2876,13 +3038,14 @@ twisted = ["twisted"] [[package]] name = "prompt-toolkit" -version = "3.0.41" +version = "3.0.43" description = "Library for building powerful interactive command lines in Python" +category = "dev" optional = false python-versions = ">=3.7.0" files = [ - {file = "prompt_toolkit-3.0.41-py3-none-any.whl", hash = "sha256:f36fe301fafb7470e86aaf90f036eef600a3210be4decf461a5b1ca8403d3cb2"}, - {file = "prompt_toolkit-3.0.41.tar.gz", hash = "sha256:941367d97fc815548822aa26c2a269fdc4eb21e9ec05fc5d447cf09bad5d75f0"}, + {file = "prompt_toolkit-3.0.43-py3-none-any.whl", hash = "sha256:a11a29cb3bf0a28a387fe5122cdb649816a957cd9261dcedf8c9f1fef33eacf6"}, + {file = "prompt_toolkit-3.0.43.tar.gz", hash = "sha256:3527b7af26106cbc65a040bcc84839a3566ec1b051bb0bfe953631e704b0ff7d"}, ] [package.dependencies] @@ -2892,6 +3055,7 @@ wcwidth = "*" name = "protobuf" version = "4.25.1" description = "" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2910,27 +3074,28 @@ files = [ [[package]] name = "psutil" -version = "5.9.6" +version = "5.9.7" description = "Cross-platform lib for process and system monitoring in Python." 
+category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ - {file = "psutil-5.9.6-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:fb8a697f11b0f5994550555fcfe3e69799e5b060c8ecf9e2f75c69302cc35c0d"}, - {file = "psutil-5.9.6-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:91ecd2d9c00db9817a4b4192107cf6954addb5d9d67a969a4f436dbc9200f88c"}, - {file = "psutil-5.9.6-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:10e8c17b4f898d64b121149afb136c53ea8b68c7531155147867b7b1ac9e7e28"}, - {file = "psutil-5.9.6-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:18cd22c5db486f33998f37e2bb054cc62fd06646995285e02a51b1e08da97017"}, - {file = "psutil-5.9.6-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:ca2780f5e038379e520281e4c032dddd086906ddff9ef0d1b9dcf00710e5071c"}, - {file = "psutil-5.9.6-cp27-none-win32.whl", hash = "sha256:70cb3beb98bc3fd5ac9ac617a327af7e7f826373ee64c80efd4eb2856e5051e9"}, - {file = "psutil-5.9.6-cp27-none-win_amd64.whl", hash = "sha256:51dc3d54607c73148f63732c727856f5febec1c7c336f8f41fcbd6315cce76ac"}, - {file = "psutil-5.9.6-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c69596f9fc2f8acd574a12d5f8b7b1ba3765a641ea5d60fb4736bf3c08a8214a"}, - {file = "psutil-5.9.6-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:92e0cc43c524834af53e9d3369245e6cc3b130e78e26100d1f63cdb0abeb3d3c"}, - {file = "psutil-5.9.6-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:748c9dd2583ed86347ed65d0035f45fa8c851e8d90354c122ab72319b5f366f4"}, - {file = "psutil-5.9.6-cp36-cp36m-win32.whl", hash = "sha256:3ebf2158c16cc69db777e3c7decb3c0f43a7af94a60d72e87b2823aebac3d602"}, - {file = "psutil-5.9.6-cp36-cp36m-win_amd64.whl", hash = "sha256:ff18b8d1a784b810df0b0fff3bcb50ab941c3b8e2c8de5726f9c71c601c611aa"}, - {file = "psutil-5.9.6-cp37-abi3-win32.whl", hash = 
"sha256:a6f01f03bf1843280f4ad16f4bde26b817847b4c1a0db59bf6419807bc5ce05c"}, - {file = "psutil-5.9.6-cp37-abi3-win_amd64.whl", hash = "sha256:6e5fb8dc711a514da83098bc5234264e551ad980cec5f85dabf4d38ed6f15e9a"}, - {file = "psutil-5.9.6-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:daecbcbd29b289aac14ece28eca6a3e60aa361754cf6da3dfb20d4d32b6c7f57"}, - {file = "psutil-5.9.6.tar.gz", hash = "sha256:e4b92ddcd7dd4cdd3f900180ea1e104932c7bce234fb88976e2a3b296441225a"}, + {file = "psutil-5.9.7-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:0bd41bf2d1463dfa535942b2a8f0e958acf6607ac0be52265ab31f7923bcd5e6"}, + {file = "psutil-5.9.7-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:5794944462509e49d4d458f4dbfb92c47539e7d8d15c796f141f474010084056"}, + {file = "psutil-5.9.7-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:fe361f743cb3389b8efda21980d93eb55c1f1e3898269bc9a2a1d0bb7b1f6508"}, + {file = "psutil-5.9.7-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:e469990e28f1ad738f65a42dcfc17adaed9d0f325d55047593cb9033a0ab63df"}, + {file = "psutil-5.9.7-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:3c4747a3e2ead1589e647e64aad601981f01b68f9398ddf94d01e3dc0d1e57c7"}, + {file = "psutil-5.9.7-cp27-none-win32.whl", hash = "sha256:1d4bc4a0148fdd7fd8f38e0498639ae128e64538faa507df25a20f8f7fb2341c"}, + {file = "psutil-5.9.7-cp27-none-win_amd64.whl", hash = "sha256:4c03362e280d06bbbfcd52f29acd79c733e0af33d707c54255d21029b8b32ba6"}, + {file = "psutil-5.9.7-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ea36cc62e69a13ec52b2f625c27527f6e4479bca2b340b7a452af55b34fcbe2e"}, + {file = "psutil-5.9.7-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1132704b876e58d277168cd729d64750633d5ff0183acf5b3c986b8466cd0284"}, + {file = "psutil-5.9.7-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:fe8b7f07948f1304497ce4f4684881250cd859b16d06a1dc4d7941eeb6233bfe"}, + {file = "psutil-5.9.7-cp36-cp36m-win32.whl", hash = "sha256:b27f8fdb190c8c03914f908a4555159327d7481dac2f01008d483137ef3311a9"}, + {file = "psutil-5.9.7-cp36-cp36m-win_amd64.whl", hash = "sha256:44969859757f4d8f2a9bd5b76eba8c3099a2c8cf3992ff62144061e39ba8568e"}, + {file = "psutil-5.9.7-cp37-abi3-win32.whl", hash = "sha256:c727ca5a9b2dd5193b8644b9f0c883d54f1248310023b5ad3e92036c5e2ada68"}, + {file = "psutil-5.9.7-cp37-abi3-win_amd64.whl", hash = "sha256:f37f87e4d73b79e6c5e749440c3113b81d1ee7d26f21c19c47371ddea834f414"}, + {file = "psutil-5.9.7-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:032f4f2c909818c86cea4fe2cc407f1c0f0cde8e6c6d702b28b8ce0c0d143340"}, + {file = "psutil-5.9.7.tar.gz", hash = "sha256:3f02134e82cfb5d089fddf20bb2e03fd5cd52395321d1c8458a9e58500ff417c"}, ] [package.extras] @@ -2940,6 +3105,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" +category = "dev" optional = false python-versions = "*" files = [ @@ -2951,6 +3117,7 @@ files = [ name = "pure-eval" version = "0.2.2" description = "Safely evaluate AST nodes without side effects" +category = "dev" optional = false python-versions = "*" files = [ @@ -2963,67 +3130,58 @@ tests = ["pytest"] [[package]] name = "pyarrow" -version = "14.0.1" +version = "14.0.2" description = "Python library for Apache Arrow" +category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "pyarrow-14.0.1-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:96d64e5ba7dceb519a955e5eeb5c9adcfd63f73a56aea4722e2cc81364fc567a"}, - {file = "pyarrow-14.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1a8ae88c0038d1bc362a682320112ee6774f006134cd5afc291591ee4bc06505"}, - {file = "pyarrow-14.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:0f6f053cb66dc24091f5511e5920e45c83107f954a21032feadc7b9e3a8e7851"}, - {file = "pyarrow-14.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:906b0dc25f2be12e95975722f1e60e162437023f490dbd80d0deb7375baf3171"}, - {file = "pyarrow-14.0.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:78d4a77a46a7de9388b653af1c4ce539350726cd9af62e0831e4f2bd0c95a2f4"}, - {file = "pyarrow-14.0.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:06ca79080ef89d6529bb8e5074d4b4f6086143b2520494fcb7cf8a99079cde93"}, - {file = "pyarrow-14.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:32542164d905002c42dff896efdac79b3bdd7291b1b74aa292fac8450d0e4dcd"}, - {file = "pyarrow-14.0.1-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:c7331b4ed3401b7ee56f22c980608cf273f0380f77d0f73dd3c185f78f5a6220"}, - {file = "pyarrow-14.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:922e8b49b88da8633d6cac0e1b5a690311b6758d6f5d7c2be71acb0f1e14cd61"}, - {file = "pyarrow-14.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58c889851ca33f992ea916b48b8540735055201b177cb0dcf0596a495a667b00"}, - {file = "pyarrow-14.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30d8494870d9916bb53b2a4384948491444741cb9a38253c590e21f836b01222"}, - {file = "pyarrow-14.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:be28e1a07f20391bb0b15ea03dcac3aade29fc773c5eb4bee2838e9b2cdde0cb"}, - {file = "pyarrow-14.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:981670b4ce0110d8dcb3246410a4aabf5714db5d8ea63b15686bce1c914b1f83"}, - {file = "pyarrow-14.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:4756a2b373a28f6166c42711240643fb8bd6322467e9aacabd26b488fa41ec23"}, - {file = "pyarrow-14.0.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:cf87e2cec65dd5cf1aa4aba918d523ef56ef95597b545bbaad01e6433851aa10"}, - {file = "pyarrow-14.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:470ae0194fbfdfbf4a6b65b4f9e0f6e1fa0ea5b90c1ee6b65b38aecee53508c8"}, - {file = "pyarrow-14.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6263cffd0c3721c1e348062997babdf0151301f7353010c9c9a8ed47448f82ab"}, - {file = "pyarrow-14.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a8089d7e77d1455d529dbd7cff08898bbb2666ee48bc4085203af1d826a33cc"}, - {file = "pyarrow-14.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:fada8396bc739d958d0b81d291cfd201126ed5e7913cb73de6bc606befc30226"}, - {file = "pyarrow-14.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:2a145dab9ed7849fc1101bf03bcdc69913547f10513fdf70fc3ab6c0a50c7eee"}, - {file = "pyarrow-14.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:05fe7994745b634c5fb16ce5717e39a1ac1fac3e2b0795232841660aa76647cd"}, - {file = "pyarrow-14.0.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:a8eeef015ae69d104c4c3117a6011e7e3ecd1abec79dc87fd2fac6e442f666ee"}, - {file = "pyarrow-14.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3c76807540989fe8fcd02285dd15e4f2a3da0b09d27781abec3adc265ddbeba1"}, - {file = "pyarrow-14.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:450e4605e3c20e558485f9161a79280a61c55efe585d51513c014de9ae8d393f"}, - {file = "pyarrow-14.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:323cbe60210173ffd7db78bfd50b80bdd792c4c9daca8843ef3cd70b186649db"}, - {file = "pyarrow-14.0.1-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0140c7e2b740e08c5a459439d87acd26b747fc408bde0a8806096ee0baaa0c15"}, - {file = "pyarrow-14.0.1-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:e592e482edd9f1ab32f18cd6a716c45b2c0f2403dc2af782f4e9674952e6dd27"}, - {file = "pyarrow-14.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:d264ad13605b61959f2ae7c1d25b1a5b8505b112715c961418c8396433f213ad"}, - {file = "pyarrow-14.0.1-cp39-cp39-macosx_10_14_x86_64.whl", hash = 
"sha256:01e44de9749cddc486169cb632f3c99962318e9dacac7778315a110f4bf8a450"}, - {file = "pyarrow-14.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d0351fecf0e26e152542bc164c22ea2a8e8c682726fce160ce4d459ea802d69c"}, - {file = "pyarrow-14.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33c1f6110c386464fd2e5e4ea3624466055bbe681ff185fd6c9daa98f30a3f9a"}, - {file = "pyarrow-14.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11e045dfa09855b6d3e7705a37c42e2dc2c71d608fab34d3c23df2e02df9aec3"}, - {file = "pyarrow-14.0.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:097828b55321897db0e1dbfc606e3ff8101ae5725673498cbfa7754ee0da80e4"}, - {file = "pyarrow-14.0.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:1daab52050a1c48506c029e6fa0944a7b2436334d7e44221c16f6f1b2cc9c510"}, - {file = "pyarrow-14.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:3f6d5faf4f1b0d5a7f97be987cf9e9f8cd39902611e818fe134588ee99bf0283"}, - {file = "pyarrow-14.0.1.tar.gz", hash = "sha256:b8b3f4fe8d4ec15e1ef9b599b94683c5216adaed78d5cb4c606180546d1e2ee1"}, + {file = "pyarrow-14.0.2-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:ba9fe808596c5dbd08b3aeffe901e5f81095baaa28e7d5118e01354c64f22807"}, + {file = "pyarrow-14.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:22a768987a16bb46220cef490c56c671993fbee8fd0475febac0b3e16b00a10e"}, + {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2dbba05e98f247f17e64303eb876f4a80fcd32f73c7e9ad975a83834d81f3fda"}, + {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a898d134d00b1eca04998e9d286e19653f9d0fcb99587310cd10270907452a6b"}, + {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:87e879323f256cb04267bb365add7208f302df942eb943c93a9dfeb8f44840b1"}, + {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = 
"sha256:76fc257559404ea5f1306ea9a3ff0541bf996ff3f7b9209fc517b5e83811fa8e"}, + {file = "pyarrow-14.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:b0c4a18e00f3a32398a7f31da47fefcd7a927545b396e1f15d0c85c2f2c778cd"}, + {file = "pyarrow-14.0.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:87482af32e5a0c0cce2d12eb3c039dd1d853bd905b04f3f953f147c7a196915b"}, + {file = "pyarrow-14.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:059bd8f12a70519e46cd64e1ba40e97eae55e0cbe1695edd95384653d7626b23"}, + {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f16111f9ab27e60b391c5f6d197510e3ad6654e73857b4e394861fc79c37200"}, + {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06ff1264fe4448e8d02073f5ce45a9f934c0f3db0a04460d0b01ff28befc3696"}, + {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd4f4b472ccf4042f1eab77e6c8bce574543f54d2135c7e396f413046397d5a"}, + {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:32356bfb58b36059773f49e4e214996888eeea3a08893e7dbde44753799b2a02"}, + {file = "pyarrow-14.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:52809ee69d4dbf2241c0e4366d949ba035cbcf48409bf404f071f624ed313a2b"}, + {file = "pyarrow-14.0.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c87824a5ac52be210d32906c715f4ed7053d0180c1060ae3ff9b7e560f53f944"}, + {file = "pyarrow-14.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a25eb2421a58e861f6ca91f43339d215476f4fe159eca603c55950c14f378cc5"}, + {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c1da70d668af5620b8ba0a23f229030a4cd6c5f24a616a146f30d2386fec422"}, + {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cc61593c8e66194c7cdfae594503e91b926a228fba40b5cf25cc593563bcd07"}, + {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = 
"sha256:78ea56f62fb7c0ae8ecb9afdd7893e3a7dbeb0b04106f5c08dbb23f9c0157591"}, + {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:37c233ddbce0c67a76c0985612fef27c0c92aef9413cf5aa56952f359fcb7379"}, + {file = "pyarrow-14.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:e4b123ad0f6add92de898214d404e488167b87b5dd86e9a434126bc2b7a5578d"}, + {file = "pyarrow-14.0.2-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:e354fba8490de258be7687f341bc04aba181fc8aa1f71e4584f9890d9cb2dec2"}, + {file = "pyarrow-14.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:20e003a23a13da963f43e2b432483fdd8c38dc8882cd145f09f21792e1cf22a1"}, + {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc0de7575e841f1595ac07e5bc631084fd06ca8b03c0f2ecece733d23cd5102a"}, + {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66e986dc859712acb0bd45601229021f3ffcdfc49044b64c6d071aaf4fa49e98"}, + {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:f7d029f20ef56673a9730766023459ece397a05001f4e4d13805111d7c2108c0"}, + {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:209bac546942b0d8edc8debda248364f7f668e4aad4741bae58e67d40e5fcf75"}, + {file = "pyarrow-14.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:1e6987c5274fb87d66bb36816afb6f65707546b3c45c44c28e3c4133c010a881"}, + {file = "pyarrow-14.0.2-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:a01d0052d2a294a5f56cc1862933014e696aa08cc7b620e8c0cce5a5d362e976"}, + {file = "pyarrow-14.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a51fee3a7db4d37f8cda3ea96f32530620d43b0489d169b285d774da48ca9785"}, + {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64df2bf1ef2ef14cee531e2dfe03dd924017650ffaa6f9513d7a1bb291e59c15"}, + {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:3c0fa3bfdb0305ffe09810f9d3e2e50a2787e3a07063001dcd7adae0cee3601a"}, + {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c65bf4fd06584f058420238bc47a316e80dda01ec0dfb3044594128a6c2db794"}, + {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:63ac901baec9369d6aae1cbe6cca11178fb018a8d45068aaf5bb54f94804a866"}, + {file = "pyarrow-14.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:75ee0efe7a87a687ae303d63037d08a48ef9ea0127064df18267252cfe2e9541"}, + {file = "pyarrow-14.0.2.tar.gz", hash = "sha256:36cef6ba12b499d864d1def3e990f97949e0b79400d08b7cf74504ffbd3eb025"}, ] [package.dependencies] numpy = ">=1.16.6" -[[package]] -name = "pyarrow-hotfix" -version = "0.6" -description = "" -optional = false -python-versions = ">=3.5" -files = [ - {file = "pyarrow_hotfix-0.6-py3-none-any.whl", hash = "sha256:dcc9ae2d220dff0083be6a9aa8e0cdee5182ad358d4931fce825c545e5c89178"}, - {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"}, -] - [[package]] name = "pycln" version = "2.4.0" description = "A formatter for finding and removing unused import statements." +category = "dev" optional = false python-versions = ">=3.7.0,<4" files = [ @@ -3042,6 +3200,7 @@ typer = ">=0.4.1" name = "pycparser" version = "2.21" description = "C parser in Python" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3053,6 +3212,7 @@ files = [ name = "pygments" version = "2.17.2" description = "Pygments is a syntax highlighting package written in Python." 
+category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3066,13 +3226,14 @@ windows-terminal = ["colorama (>=0.4.6)"] [[package]] name = "pytest" -version = "7.4.3" +version = "7.4.4" description = "pytest: simple powerful testing with Python" +category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"}, - {file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"}, + {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, + {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, ] [package.dependencies] @@ -3090,6 +3251,7 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no name = "pytest-cov" version = "4.1.0" description = "Pytest plugin for measuring coverage." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3106,13 +3268,14 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale [[package]] name = "pytest-doctestplus" -version = "1.0.0" +version = "1.1.0" description = "Pytest plugin with advanced doctest features." 
+category = "dev" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pytest-doctestplus-1.0.0.tar.gz", hash = "sha256:f650440dcaede13ed6d7da73bfb4ac585d40a80444ba3542d3e6eecdb275d49f"}, - {file = "pytest_doctestplus-1.0.0-py3-none-any.whl", hash = "sha256:dcba88e1e38bc4871c355e44b778ccfd49b25e33f6aa5393eed6b56440decb2a"}, + {file = "pytest-doctestplus-1.1.0.tar.gz", hash = "sha256:ea0a710f1b6a3571ed971fb6d6e5db05a2ae6b91b0fbcafe30fb5ea40e9987c4"}, + {file = "pytest_doctestplus-1.1.0-py3-none-any.whl", hash = "sha256:b98d95b4956a03256c638f1f9f72200160e9885ab1cd40f35c4453bc1d2e32b2"}, ] [package.dependencies] @@ -3127,6 +3290,7 @@ test = ["numpy", "pytest-remotedata (>=0.3.2)", "sphinx"] name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -3141,6 +3305,7 @@ six = ">=1.5" name = "python-json-logger" version = "2.0.7" description = "A python library adding a json log formatter" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3152,6 +3317,7 @@ files = [ name = "pytz" version = "2023.3.post1" description = "World timezone definitions, modern and historical" +category = "main" optional = false python-versions = "*" files = [ @@ -3163,6 +3329,7 @@ files = [ name = "pywin32" version = "306" description = "Python for Window Extensions" +category = "dev" optional = false python-versions = "*" files = [ @@ -3186,6 +3353,7 @@ files = [ name = "pywinpty" version = "2.0.12" description = "Pseudo terminal support for Windows from Python." 
+category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3201,6 +3369,7 @@ files = [ name = "pyyaml" version = "6.0.1" description = "YAML parser and emitter for Python" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3260,6 +3429,7 @@ files = [ name = "pyzmq" version = "25.1.2" description = "Python bindings for 0MQ" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3365,6 +3535,7 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} name = "qtconsole" version = "5.5.1" description = "Jupyter Qt console" +category = "dev" optional = false python-versions = ">= 3.8" files = [ @@ -3390,6 +3561,7 @@ test = ["flaky", "pytest", "pytest-qt"] name = "qtpy" version = "2.4.1" description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3405,13 +3577,14 @@ test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] [[package]] name = "referencing" -version = "0.32.0" +version = "0.32.1" description = "JSON Referencing + Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "referencing-0.32.0-py3-none-any.whl", hash = "sha256:bdcd3efb936f82ff86f993093f6da7435c7de69a3b3a5a06678a6050184bee99"}, - {file = "referencing-0.32.0.tar.gz", hash = "sha256:689e64fe121843dcfd57b71933318ef1f91188ffb45367332700a86ac8fd6161"}, + {file = "referencing-0.32.1-py3-none-any.whl", hash = "sha256:7e4dc12271d8e15612bfe35792f5ea1c40970dadf8624602e33db2758f7ee554"}, + {file = "referencing-0.32.1.tar.gz", hash = "sha256:3c57da0513e9563eb7e203ebe9bb3a1b509b042016433bd1e45a2853466c3dd3"}, ] [package.dependencies] @@ -3420,105 +3593,112 @@ rpds-py = ">=0.7.0" [[package]] name = "regex" -version = "2023.10.3" +version = "2023.12.25" description = "Alternative regular expression module, to replace re." 
+category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "regex-2023.10.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4c34d4f73ea738223a094d8e0ffd6d2c1a1b4c175da34d6b0de3d8d69bee6bcc"}, - {file = "regex-2023.10.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a8f4e49fc3ce020f65411432183e6775f24e02dff617281094ba6ab079ef0915"}, - {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4cd1bccf99d3ef1ab6ba835308ad85be040e6a11b0977ef7ea8c8005f01a3c29"}, - {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:81dce2ddc9f6e8f543d94b05d56e70d03a0774d32f6cca53e978dc01e4fc75b8"}, - {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c6b4d23c04831e3ab61717a707a5d763b300213db49ca680edf8bf13ab5d91b"}, - {file = "regex-2023.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c15ad0aee158a15e17e0495e1e18741573d04eb6da06d8b84af726cfc1ed02ee"}, - {file = "regex-2023.10.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6239d4e2e0b52c8bd38c51b760cd870069f0bdf99700a62cd509d7a031749a55"}, - {file = "regex-2023.10.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4a8bf76e3182797c6b1afa5b822d1d5802ff30284abe4599e1247be4fd6b03be"}, - {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d9c727bbcf0065cbb20f39d2b4f932f8fa1631c3e01fcedc979bd4f51fe051c5"}, - {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:3ccf2716add72f80714b9a63899b67fa711b654be3fcdd34fa391d2d274ce767"}, - {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:107ac60d1bfdc3edb53be75e2a52aff7481b92817cfdddd9b4519ccf0e54a6ff"}, - {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = 
"sha256:00ba3c9818e33f1fa974693fb55d24cdc8ebafcb2e4207680669d8f8d7cca79a"}, - {file = "regex-2023.10.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f0a47efb1dbef13af9c9a54a94a0b814902e547b7f21acb29434504d18f36e3a"}, - {file = "regex-2023.10.3-cp310-cp310-win32.whl", hash = "sha256:36362386b813fa6c9146da6149a001b7bd063dabc4d49522a1f7aa65b725c7ec"}, - {file = "regex-2023.10.3-cp310-cp310-win_amd64.whl", hash = "sha256:c65a3b5330b54103e7d21cac3f6bf3900d46f6d50138d73343d9e5b2900b2353"}, - {file = "regex-2023.10.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:90a79bce019c442604662d17bf69df99090e24cdc6ad95b18b6725c2988a490e"}, - {file = "regex-2023.10.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c7964c2183c3e6cce3f497e3a9f49d182e969f2dc3aeeadfa18945ff7bdd7051"}, - {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ef80829117a8061f974b2fda8ec799717242353bff55f8a29411794d635d964"}, - {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5addc9d0209a9afca5fc070f93b726bf7003bd63a427f65ef797a931782e7edc"}, - {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c148bec483cc4b421562b4bcedb8e28a3b84fcc8f0aa4418e10898f3c2c0eb9b"}, - {file = "regex-2023.10.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d1f21af4c1539051049796a0f50aa342f9a27cde57318f2fc41ed50b0dbc4ac"}, - {file = "regex-2023.10.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b9ac09853b2a3e0d0082104036579809679e7715671cfbf89d83c1cb2a30f58"}, - {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ebedc192abbc7fd13c5ee800e83a6df252bec691eb2c4bedc9f8b2e2903f5e2a"}, - {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d8a993c0a0ffd5f2d3bda23d0cd75e7086736f8f8268de8a82fbc4bd0ac6791e"}, - {file = 
"regex-2023.10.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:be6b7b8d42d3090b6c80793524fa66c57ad7ee3fe9722b258aec6d0672543fd0"}, - {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4023e2efc35a30e66e938de5aef42b520c20e7eda7bb5fb12c35e5d09a4c43f6"}, - {file = "regex-2023.10.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0d47840dc05e0ba04fe2e26f15126de7c755496d5a8aae4a08bda4dd8d646c54"}, - {file = "regex-2023.10.3-cp311-cp311-win32.whl", hash = "sha256:9145f092b5d1977ec8c0ab46e7b3381b2fd069957b9862a43bd383e5c01d18c2"}, - {file = "regex-2023.10.3-cp311-cp311-win_amd64.whl", hash = "sha256:b6104f9a46bd8743e4f738afef69b153c4b8b592d35ae46db07fc28ae3d5fb7c"}, - {file = "regex-2023.10.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bff507ae210371d4b1fe316d03433ac099f184d570a1a611e541923f78f05037"}, - {file = "regex-2023.10.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:be5e22bbb67924dea15039c3282fa4cc6cdfbe0cbbd1c0515f9223186fc2ec5f"}, - {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a992f702c9be9c72fa46f01ca6e18d131906a7180950958f766c2aa294d4b41"}, - {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7434a61b158be563c1362d9071358f8ab91b8d928728cd2882af060481244c9e"}, - {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c2169b2dcabf4e608416f7f9468737583ce5f0a6e8677c4efbf795ce81109d7c"}, - {file = "regex-2023.10.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9e908ef5889cda4de038892b9accc36d33d72fb3e12c747e2799a0e806ec841"}, - {file = "regex-2023.10.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:12bd4bc2c632742c7ce20db48e0d99afdc05e03f0b4c1af90542e05b809a03d9"}, - {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = 
"sha256:bc72c231f5449d86d6c7d9cc7cd819b6eb30134bb770b8cfdc0765e48ef9c420"}, - {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bce8814b076f0ce5766dc87d5a056b0e9437b8e0cd351b9a6c4e1134a7dfbda9"}, - {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:ba7cd6dc4d585ea544c1412019921570ebd8a597fabf475acc4528210d7c4a6f"}, - {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b0c7d2f698e83f15228ba41c135501cfe7d5740181d5903e250e47f617eb4292"}, - {file = "regex-2023.10.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5a8f91c64f390ecee09ff793319f30a0f32492e99f5dc1c72bc361f23ccd0a9a"}, - {file = "regex-2023.10.3-cp312-cp312-win32.whl", hash = "sha256:ad08a69728ff3c79866d729b095872afe1e0557251da4abb2c5faff15a91d19a"}, - {file = "regex-2023.10.3-cp312-cp312-win_amd64.whl", hash = "sha256:39cdf8d141d6d44e8d5a12a8569d5a227f645c87df4f92179bd06e2e2705e76b"}, - {file = "regex-2023.10.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4a3ee019a9befe84fa3e917a2dd378807e423d013377a884c1970a3c2792d293"}, - {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76066d7ff61ba6bf3cb5efe2428fc82aac91802844c022d849a1f0f53820502d"}, - {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfe50b61bab1b1ec260fa7cd91106fa9fece57e6beba05630afe27c71259c59b"}, - {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9fd88f373cb71e6b59b7fa597e47e518282455c2734fd4306a05ca219a1991b0"}, - {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3ab05a182c7937fb374f7e946f04fb23a0c0699c0450e9fb02ef567412d2fa3"}, - {file = "regex-2023.10.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dac37cf08fcf2094159922edc7a2784cfcc5c70f8354469f79ed085f0328ebdf"}, - {file = 
"regex-2023.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e54ddd0bb8fb626aa1f9ba7b36629564544954fff9669b15da3610c22b9a0991"}, - {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3367007ad1951fde612bf65b0dffc8fd681a4ab98ac86957d16491400d661302"}, - {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:16f8740eb6dbacc7113e3097b0a36065a02e37b47c936b551805d40340fb9971"}, - {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:f4f2ca6df64cbdd27f27b34f35adb640b5d2d77264228554e68deda54456eb11"}, - {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:39807cbcbe406efca2a233884e169d056c35aa7e9f343d4e78665246a332f597"}, - {file = "regex-2023.10.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:7eece6fbd3eae4a92d7c748ae825cbc1ee41a89bb1c3db05b5578ed3cfcfd7cb"}, - {file = "regex-2023.10.3-cp37-cp37m-win32.whl", hash = "sha256:ce615c92d90df8373d9e13acddd154152645c0dc060871abf6bd43809673d20a"}, - {file = "regex-2023.10.3-cp37-cp37m-win_amd64.whl", hash = "sha256:0f649fa32fe734c4abdfd4edbb8381c74abf5f34bc0b3271ce687b23729299ed"}, - {file = "regex-2023.10.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9b98b7681a9437262947f41c7fac567c7e1f6eddd94b0483596d320092004533"}, - {file = "regex-2023.10.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:91dc1d531f80c862441d7b66c4505cd6ea9d312f01fb2f4654f40c6fdf5cc37a"}, - {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82fcc1f1cc3ff1ab8a57ba619b149b907072e750815c5ba63e7aa2e1163384a4"}, - {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7979b834ec7a33aafae34a90aad9f914c41fd6eaa8474e66953f3f6f7cbd4368"}, - {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:ef71561f82a89af6cfcbee47f0fabfdb6e63788a9258e913955d89fdd96902ab"}, - {file = "regex-2023.10.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd829712de97753367153ed84f2de752b86cd1f7a88b55a3a775eb52eafe8a94"}, - {file = "regex-2023.10.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:00e871d83a45eee2f8688d7e6849609c2ca2a04a6d48fba3dff4deef35d14f07"}, - {file = "regex-2023.10.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:706e7b739fdd17cb89e1fbf712d9dc21311fc2333f6d435eac2d4ee81985098c"}, - {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:cc3f1c053b73f20c7ad88b0d1d23be7e7b3901229ce89f5000a8399746a6e039"}, - {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6f85739e80d13644b981a88f529d79c5bdf646b460ba190bffcaf6d57b2a9863"}, - {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:741ba2f511cc9626b7561a440f87d658aabb3d6b744a86a3c025f866b4d19e7f"}, - {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e77c90ab5997e85901da85131fd36acd0ed2221368199b65f0d11bca44549711"}, - {file = "regex-2023.10.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:979c24cbefaf2420c4e377ecd1f165ea08cc3d1fbb44bdc51bccbbf7c66a2cb4"}, - {file = "regex-2023.10.3-cp38-cp38-win32.whl", hash = "sha256:58837f9d221744d4c92d2cf7201c6acd19623b50c643b56992cbd2b745485d3d"}, - {file = "regex-2023.10.3-cp38-cp38-win_amd64.whl", hash = "sha256:c55853684fe08d4897c37dfc5faeff70607a5f1806c8be148f1695be4a63414b"}, - {file = "regex-2023.10.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2c54e23836650bdf2c18222c87f6f840d4943944146ca479858404fedeb9f9af"}, - {file = "regex-2023.10.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:69c0771ca5653c7d4b65203cbfc5e66db9375f1078689459fe196fe08b7b4930"}, - {file = 
"regex-2023.10.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ac965a998e1388e6ff2e9781f499ad1eaa41e962a40d11c7823c9952c77123e"}, - {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1c0e8fae5b27caa34177bdfa5a960c46ff2f78ee2d45c6db15ae3f64ecadde14"}, - {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6c56c3d47da04f921b73ff9415fbaa939f684d47293f071aa9cbb13c94afc17d"}, - {file = "regex-2023.10.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ef1e014eed78ab650bef9a6a9cbe50b052c0aebe553fb2881e0453717573f52"}, - {file = "regex-2023.10.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d29338556a59423d9ff7b6eb0cb89ead2b0875e08fe522f3e068b955c3e7b59b"}, - {file = "regex-2023.10.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:9c6d0ced3c06d0f183b73d3c5920727268d2201aa0fe6d55c60d68c792ff3588"}, - {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:994645a46c6a740ee8ce8df7911d4aee458d9b1bc5639bc968226763d07f00fa"}, - {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:66e2fe786ef28da2b28e222c89502b2af984858091675044d93cb50e6f46d7af"}, - {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:11175910f62b2b8c055f2b089e0fedd694fe2be3941b3e2633653bc51064c528"}, - {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:06e9abc0e4c9ab4779c74ad99c3fc10d3967d03114449acc2c2762ad4472b8ca"}, - {file = "regex-2023.10.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fb02e4257376ae25c6dd95a5aec377f9b18c09be6ebdefa7ad209b9137b73d48"}, - {file = "regex-2023.10.3-cp39-cp39-win32.whl", hash = "sha256:3b2c3502603fab52d7619b882c25a6850b766ebd1b18de3df23b2f939360e1bd"}, - {file = "regex-2023.10.3-cp39-cp39-win_amd64.whl", hash 
= "sha256:adbccd17dcaff65704c856bd29951c58a1bd4b2b0f8ad6b826dbd543fe740988"}, - {file = "regex-2023.10.3.tar.gz", hash = "sha256:3fef4f844d2290ee0ba57addcec17eec9e3df73f10a2748485dfd6a3a188cc0f"}, + {file = "regex-2023.12.25-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0694219a1d54336fd0445ea382d49d36882415c0134ee1e8332afd1529f0baa5"}, + {file = "regex-2023.12.25-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b014333bd0217ad3d54c143de9d4b9a3ca1c5a29a6d0d554952ea071cff0f1f8"}, + {file = "regex-2023.12.25-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d865984b3f71f6d0af64d0d88f5733521698f6c16f445bb09ce746c92c97c586"}, + {file = "regex-2023.12.25-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e0eabac536b4cc7f57a5f3d095bfa557860ab912f25965e08fe1545e2ed8b4c"}, + {file = "regex-2023.12.25-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c25a8ad70e716f96e13a637802813f65d8a6760ef48672aa3502f4c24ea8b400"}, + {file = "regex-2023.12.25-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a9b6d73353f777630626f403b0652055ebfe8ff142a44ec2cf18ae470395766e"}, + {file = "regex-2023.12.25-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9cc99d6946d750eb75827cb53c4371b8b0fe89c733a94b1573c9dd16ea6c9e4"}, + {file = "regex-2023.12.25-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:88d1f7bef20c721359d8675f7d9f8e414ec5003d8f642fdfd8087777ff7f94b5"}, + {file = "regex-2023.12.25-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cb3fe77aec8f1995611f966d0c656fdce398317f850d0e6e7aebdfe61f40e1cd"}, + {file = "regex-2023.12.25-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7aa47c2e9ea33a4a2a05f40fcd3ea36d73853a2aae7b4feab6fc85f8bf2c9704"}, + {file = "regex-2023.12.25-cp310-cp310-musllinux_1_1_i686.whl", hash = 
"sha256:df26481f0c7a3f8739fecb3e81bc9da3fcfae34d6c094563b9d4670b047312e1"}, + {file = "regex-2023.12.25-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:c40281f7d70baf6e0db0c2f7472b31609f5bc2748fe7275ea65a0b4601d9b392"}, + {file = "regex-2023.12.25-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:d94a1db462d5690ebf6ae86d11c5e420042b9898af5dcf278bd97d6bda065423"}, + {file = "regex-2023.12.25-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ba1b30765a55acf15dce3f364e4928b80858fa8f979ad41f862358939bdd1f2f"}, + {file = "regex-2023.12.25-cp310-cp310-win32.whl", hash = "sha256:150c39f5b964e4d7dba46a7962a088fbc91f06e606f023ce57bb347a3b2d4630"}, + {file = "regex-2023.12.25-cp310-cp310-win_amd64.whl", hash = "sha256:09da66917262d9481c719599116c7dc0c321ffcec4b1f510c4f8a066f8768105"}, + {file = "regex-2023.12.25-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:1b9d811f72210fa9306aeb88385b8f8bcef0dfbf3873410413c00aa94c56c2b6"}, + {file = "regex-2023.12.25-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d902a43085a308cef32c0d3aea962524b725403fd9373dea18110904003bac97"}, + {file = "regex-2023.12.25-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d166eafc19f4718df38887b2bbe1467a4f74a9830e8605089ea7a30dd4da8887"}, + {file = "regex-2023.12.25-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c7ad32824b7f02bb3c9f80306d405a1d9b7bb89362d68b3c5a9be53836caebdb"}, + {file = "regex-2023.12.25-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:636ba0a77de609d6510235b7f0e77ec494d2657108f777e8765efc060094c98c"}, + {file = "regex-2023.12.25-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fda75704357805eb953a3ee15a2b240694a9a514548cd49b3c5124b4e2ad01b"}, + {file = "regex-2023.12.25-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f72cbae7f6b01591f90814250e636065850c5926751af02bb48da94dfced7baa"}, + {file = 
"regex-2023.12.25-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:db2a0b1857f18b11e3b0e54ddfefc96af46b0896fb678c85f63fb8c37518b3e7"}, + {file = "regex-2023.12.25-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:7502534e55c7c36c0978c91ba6f61703faf7ce733715ca48f499d3dbbd7657e0"}, + {file = "regex-2023.12.25-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:e8c7e08bb566de4faaf11984af13f6bcf6a08f327b13631d41d62592681d24fe"}, + {file = "regex-2023.12.25-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:283fc8eed679758de38fe493b7d7d84a198b558942b03f017b1f94dda8efae80"}, + {file = "regex-2023.12.25-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:f44dd4d68697559d007462b0a3a1d9acd61d97072b71f6d1968daef26bc744bd"}, + {file = "regex-2023.12.25-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:67d3ccfc590e5e7197750fcb3a2915b416a53e2de847a728cfa60141054123d4"}, + {file = "regex-2023.12.25-cp311-cp311-win32.whl", hash = "sha256:68191f80a9bad283432385961d9efe09d783bcd36ed35a60fb1ff3f1ec2efe87"}, + {file = "regex-2023.12.25-cp311-cp311-win_amd64.whl", hash = "sha256:7d2af3f6b8419661a0c421584cfe8aaec1c0e435ce7e47ee2a97e344b98f794f"}, + {file = "regex-2023.12.25-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8a0ccf52bb37d1a700375a6b395bff5dd15c50acb745f7db30415bae3c2b0715"}, + {file = "regex-2023.12.25-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c3c4a78615b7762740531c27cf46e2f388d8d727d0c0c739e72048beb26c8a9d"}, + {file = "regex-2023.12.25-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ad83e7545b4ab69216cef4cc47e344d19622e28aabec61574b20257c65466d6a"}, + {file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7a635871143661feccce3979e1727c4e094f2bdfd3ec4b90dfd4f16f571a87a"}, + {file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:d498eea3f581fbe1b34b59c697512a8baef88212f92e4c7830fcc1499f5b45a5"}, + {file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:43f7cd5754d02a56ae4ebb91b33461dc67be8e3e0153f593c509e21d219c5060"}, + {file = "regex-2023.12.25-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51f4b32f793812714fd5307222a7f77e739b9bc566dc94a18126aba3b92b98a3"}, + {file = "regex-2023.12.25-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba99d8077424501b9616b43a2d208095746fb1284fc5ba490139651f971d39d9"}, + {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4bfc2b16e3ba8850e0e262467275dd4d62f0d045e0e9eda2bc65078c0110a11f"}, + {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8c2c19dae8a3eb0ea45a8448356ed561be843b13cbc34b840922ddf565498c1c"}, + {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:60080bb3d8617d96f0fb7e19796384cc2467447ef1c491694850ebd3670bc457"}, + {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b77e27b79448e34c2c51c09836033056a0547aa360c45eeeb67803da7b0eedaf"}, + {file = "regex-2023.12.25-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:518440c991f514331f4850a63560321f833979d145d7d81186dbe2f19e27ae3d"}, + {file = "regex-2023.12.25-cp312-cp312-win32.whl", hash = "sha256:e2610e9406d3b0073636a3a2e80db05a02f0c3169b5632022b4e81c0364bcda5"}, + {file = "regex-2023.12.25-cp312-cp312-win_amd64.whl", hash = "sha256:cc37b9aeebab425f11f27e5e9e6cf580be7206c6582a64467a14dda211abc232"}, + {file = "regex-2023.12.25-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:da695d75ac97cb1cd725adac136d25ca687da4536154cdc2815f576e4da11c69"}, + {file = "regex-2023.12.25-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d126361607b33c4eb7b36debc173bf25d7805847346dd4d99b5499e1fef52bc7"}, + {file = 
"regex-2023.12.25-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4719bb05094d7d8563a450cf8738d2e1061420f79cfcc1fa7f0a44744c4d8f73"}, + {file = "regex-2023.12.25-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5dd58946bce44b53b06d94aa95560d0b243eb2fe64227cba50017a8d8b3cd3e2"}, + {file = "regex-2023.12.25-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22a86d9fff2009302c440b9d799ef2fe322416d2d58fc124b926aa89365ec482"}, + {file = "regex-2023.12.25-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2aae8101919e8aa05ecfe6322b278f41ce2994c4a430303c4cd163fef746e04f"}, + {file = "regex-2023.12.25-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:e692296c4cc2873967771345a876bcfc1c547e8dd695c6b89342488b0ea55cd8"}, + {file = "regex-2023.12.25-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:263ef5cc10979837f243950637fffb06e8daed7f1ac1e39d5910fd29929e489a"}, + {file = "regex-2023.12.25-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:d6f7e255e5fa94642a0724e35406e6cb7001c09d476ab5fce002f652b36d0c39"}, + {file = "regex-2023.12.25-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:88ad44e220e22b63b0f8f81f007e8abbb92874d8ced66f32571ef8beb0643b2b"}, + {file = "regex-2023.12.25-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:3a17d3ede18f9cedcbe23d2daa8a2cd6f59fe2bf082c567e43083bba3fb00347"}, + {file = "regex-2023.12.25-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d15b274f9e15b1a0b7a45d2ac86d1f634d983ca40d6b886721626c47a400bf39"}, + {file = "regex-2023.12.25-cp37-cp37m-win32.whl", hash = "sha256:ed19b3a05ae0c97dd8f75a5d8f21f7723a8c33bbc555da6bbe1f96c470139d3c"}, + {file = "regex-2023.12.25-cp37-cp37m-win_amd64.whl", hash = "sha256:a6d1047952c0b8104a1d371f88f4ab62e6275567d4458c1e26e9627ad489b445"}, + {file = "regex-2023.12.25-cp38-cp38-macosx_10_9_universal2.whl", 
hash = "sha256:b43523d7bc2abd757119dbfb38af91b5735eea45537ec6ec3a5ec3f9562a1c53"}, + {file = "regex-2023.12.25-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:efb2d82f33b2212898f1659fb1c2e9ac30493ac41e4d53123da374c3b5541e64"}, + {file = "regex-2023.12.25-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b7fca9205b59c1a3d5031f7e64ed627a1074730a51c2a80e97653e3e9fa0d415"}, + {file = "regex-2023.12.25-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:086dd15e9435b393ae06f96ab69ab2d333f5d65cbe65ca5a3ef0ec9564dfe770"}, + {file = "regex-2023.12.25-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e81469f7d01efed9b53740aedd26085f20d49da65f9c1f41e822a33992cb1590"}, + {file = "regex-2023.12.25-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:34e4af5b27232f68042aa40a91c3b9bb4da0eeb31b7632e0091afc4310afe6cb"}, + {file = "regex-2023.12.25-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9852b76ab558e45b20bf1893b59af64a28bd3820b0c2efc80e0a70a4a3ea51c1"}, + {file = "regex-2023.12.25-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ff100b203092af77d1a5a7abe085b3506b7eaaf9abf65b73b7d6905b6cb76988"}, + {file = "regex-2023.12.25-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cc038b2d8b1470364b1888a98fd22d616fba2b6309c5b5f181ad4483e0017861"}, + {file = "regex-2023.12.25-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:094ba386bb5c01e54e14434d4caabf6583334090865b23ef58e0424a6286d3dc"}, + {file = "regex-2023.12.25-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5cd05d0f57846d8ba4b71d9c00f6f37d6b97d5e5ef8b3c3840426a475c8f70f4"}, + {file = "regex-2023.12.25-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:9aa1a67bbf0f957bbe096375887b2505f5d8ae16bf04488e8b0f334c36e31360"}, + {file = "regex-2023.12.25-cp38-cp38-musllinux_1_1_s390x.whl", hash = 
"sha256:98a2636994f943b871786c9e82bfe7883ecdaba2ef5df54e1450fa9869d1f756"}, + {file = "regex-2023.12.25-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:37f8e93a81fc5e5bd8db7e10e62dc64261bcd88f8d7e6640aaebe9bc180d9ce2"}, + {file = "regex-2023.12.25-cp38-cp38-win32.whl", hash = "sha256:d78bd484930c1da2b9679290a41cdb25cc127d783768a0369d6b449e72f88beb"}, + {file = "regex-2023.12.25-cp38-cp38-win_amd64.whl", hash = "sha256:b521dcecebc5b978b447f0f69b5b7f3840eac454862270406a39837ffae4e697"}, + {file = "regex-2023.12.25-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f7bc09bc9c29ebead055bcba136a67378f03d66bf359e87d0f7c759d6d4ffa31"}, + {file = "regex-2023.12.25-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e14b73607d6231f3cc4622809c196b540a6a44e903bcfad940779c80dffa7be7"}, + {file = "regex-2023.12.25-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9eda5f7a50141291beda3edd00abc2d4a5b16c29c92daf8d5bd76934150f3edc"}, + {file = "regex-2023.12.25-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc6bb9aa69aacf0f6032c307da718f61a40cf970849e471254e0e91c56ffca95"}, + {file = "regex-2023.12.25-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:298dc6354d414bc921581be85695d18912bea163a8b23cac9a2562bbcd5088b1"}, + {file = "regex-2023.12.25-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f4e475a80ecbd15896a976aa0b386c5525d0ed34d5c600b6d3ebac0a67c7ddf"}, + {file = "regex-2023.12.25-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:531ac6cf22b53e0696f8e1d56ce2396311254eb806111ddd3922c9d937151dae"}, + {file = "regex-2023.12.25-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22f3470f7524b6da61e2020672df2f3063676aff444db1daa283c2ea4ed259d6"}, + {file = "regex-2023.12.25-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = 
"sha256:89723d2112697feaa320c9d351e5f5e7b841e83f8b143dba8e2d2b5f04e10923"}, + {file = "regex-2023.12.25-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0ecf44ddf9171cd7566ef1768047f6e66975788258b1c6c6ca78098b95cf9a3d"}, + {file = "regex-2023.12.25-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:905466ad1702ed4acfd67a902af50b8db1feeb9781436372261808df7a2a7bca"}, + {file = "regex-2023.12.25-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:4558410b7a5607a645e9804a3e9dd509af12fb72b9825b13791a37cd417d73a5"}, + {file = "regex-2023.12.25-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:7e316026cc1095f2a3e8cc012822c99f413b702eaa2ca5408a513609488cb62f"}, + {file = "regex-2023.12.25-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3b1de218d5375cd6ac4b5493e0b9f3df2be331e86520f23382f216c137913d20"}, + {file = "regex-2023.12.25-cp39-cp39-win32.whl", hash = "sha256:11a963f8e25ab5c61348d090bf1b07f1953929c13bd2309a0662e9ff680763c9"}, + {file = "regex-2023.12.25-cp39-cp39-win_amd64.whl", hash = "sha256:e693e233ac92ba83a87024e1d32b5f9ab15ca55ddd916d878146f4e3406b5c91"}, + {file = "regex-2023.12.25.tar.gz", hash = "sha256:29171aa128da69afdf4bde412d5bedc335f2ca8fcfe4489038577d05f16181e5"}, ] [[package]] name = "requests" version = "2.31.0" description = "Python HTTP for Humans." 
+category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3540,6 +3720,7 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "rfc3339-validator" version = "0.1.4" description = "A pure python RFC3339 validator" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -3554,6 +3735,7 @@ six = "*" name = "rfc3986-validator" version = "0.1.1" description = "Pure python rfc3986 validator" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -3565,6 +3747,7 @@ files = [ name = "rich" version = "13.7.0" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -3582,116 +3765,118 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] [[package]] name = "rpds-py" -version = "0.13.2" +version = "0.16.2" description = "Python bindings to Rust's persistent data structures (rpds)" +category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "rpds_py-0.13.2-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:1ceebd0ae4f3e9b2b6b553b51971921853ae4eebf3f54086be0565d59291e53d"}, - {file = "rpds_py-0.13.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:46e1ed994a0920f350a4547a38471217eb86f57377e9314fbaaa329b71b7dfe3"}, - {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee353bb51f648924926ed05e0122b6a0b1ae709396a80eb583449d5d477fcdf7"}, - {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:530190eb0cd778363bbb7596612ded0bb9fef662daa98e9d92a0419ab27ae914"}, - {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29d311e44dd16d2434d5506d57ef4d7036544fc3c25c14b6992ef41f541b10fb"}, - {file = 
"rpds_py-0.13.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e72f750048b32d39e87fc85c225c50b2a6715034848dbb196bf3348aa761fa1"}, - {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db09b98c7540df69d4b47218da3fbd7cb466db0fb932e971c321f1c76f155266"}, - {file = "rpds_py-0.13.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2ac26f50736324beb0282c819668328d53fc38543fa61eeea2c32ea8ea6eab8d"}, - {file = "rpds_py-0.13.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:12ecf89bd54734c3c2c79898ae2021dca42750c7bcfb67f8fb3315453738ac8f"}, - {file = "rpds_py-0.13.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a44c8440183b43167fd1a0819e8356692bf5db1ad14ce140dbd40a1485f2dea"}, - {file = "rpds_py-0.13.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bcef4f2d3dc603150421de85c916da19471f24d838c3c62a4f04c1eb511642c1"}, - {file = "rpds_py-0.13.2-cp310-none-win32.whl", hash = "sha256:ee6faebb265e28920a6f23a7d4c362414b3f4bb30607141d718b991669e49ddc"}, - {file = "rpds_py-0.13.2-cp310-none-win_amd64.whl", hash = "sha256:ac96d67b37f28e4b6ecf507c3405f52a40658c0a806dffde624a8fcb0314d5fd"}, - {file = "rpds_py-0.13.2-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:b5f6328e8e2ae8238fc767703ab7b95785521c42bb2b8790984e3477d7fa71ad"}, - {file = "rpds_py-0.13.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:729408136ef8d45a28ee9a7411917c9e3459cf266c7e23c2f7d4bb8ef9e0da42"}, - {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:65cfed9c807c27dee76407e8bb29e6f4e391e436774bcc769a037ff25ad8646e"}, - {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:aefbdc934115d2f9278f153952003ac52cd2650e7313750390b334518c589568"}, - {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:d48db29bd47814671afdd76c7652aefacc25cf96aad6daefa82d738ee87461e2"}, - {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3c55d7f2d817183d43220738270efd3ce4e7a7b7cbdaefa6d551ed3d6ed89190"}, - {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6aadae3042f8e6db3376d9e91f194c606c9a45273c170621d46128f35aef7cd0"}, - {file = "rpds_py-0.13.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5feae2f9aa7270e2c071f488fab256d768e88e01b958f123a690f1cc3061a09c"}, - {file = "rpds_py-0.13.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:51967a67ea0d7b9b5cd86036878e2d82c0b6183616961c26d825b8c994d4f2c8"}, - {file = "rpds_py-0.13.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d0c10d803549427f427085ed7aebc39832f6e818a011dcd8785e9c6a1ba9b3e"}, - {file = "rpds_py-0.13.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:603d5868f7419081d616dab7ac3cfa285296735e7350f7b1e4f548f6f953ee7d"}, - {file = "rpds_py-0.13.2-cp311-none-win32.whl", hash = "sha256:b8996ffb60c69f677245f5abdbcc623e9442bcc91ed81b6cd6187129ad1fa3e7"}, - {file = "rpds_py-0.13.2-cp311-none-win_amd64.whl", hash = "sha256:5379e49d7e80dca9811b36894493d1c1ecb4c57de05c36f5d0dd09982af20211"}, - {file = "rpds_py-0.13.2-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:8a776a29b77fe0cc28fedfd87277b0d0f7aa930174b7e504d764e0b43a05f381"}, - {file = "rpds_py-0.13.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2a1472956c5bcc49fb0252b965239bffe801acc9394f8b7c1014ae9258e4572b"}, - {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f252dfb4852a527987a9156cbcae3022a30f86c9d26f4f17b8c967d7580d65d2"}, - {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f0d320e70b6b2300ff6029e234e79fe44e9dbbfc7b98597ba28e054bd6606a57"}, - {file = 
"rpds_py-0.13.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ade2ccb937060c299ab0dfb2dea3d2ddf7e098ed63ee3d651ebfc2c8d1e8632a"}, - {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9d121be0217787a7d59a5c6195b0842d3f701007333426e5154bf72346aa658"}, - {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8fa6bd071ec6d90f6e7baa66ae25820d57a8ab1b0a3c6d3edf1834d4b26fafa2"}, - {file = "rpds_py-0.13.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c918621ee0a3d1fe61c313f2489464f2ae3d13633e60f520a8002a5e910982ee"}, - {file = "rpds_py-0.13.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:25b28b3d33ec0a78e944aaaed7e5e2a94ac811bcd68b557ca48a0c30f87497d2"}, - {file = "rpds_py-0.13.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:31e220a040b89a01505128c2f8a59ee74732f666439a03e65ccbf3824cdddae7"}, - {file = "rpds_py-0.13.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:15253fff410873ebf3cfba1cc686a37711efcd9b8cb30ea21bb14a973e393f60"}, - {file = "rpds_py-0.13.2-cp312-none-win32.whl", hash = "sha256:b981a370f8f41c4024c170b42fbe9e691ae2dbc19d1d99151a69e2c84a0d194d"}, - {file = "rpds_py-0.13.2-cp312-none-win_amd64.whl", hash = "sha256:4c4e314d36d4f31236a545696a480aa04ea170a0b021e9a59ab1ed94d4c3ef27"}, - {file = "rpds_py-0.13.2-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:80e5acb81cb49fd9f2d5c08f8b74ffff14ee73b10ca88297ab4619e946bcb1e1"}, - {file = "rpds_py-0.13.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:efe093acc43e869348f6f2224df7f452eab63a2c60a6c6cd6b50fd35c4e075ba"}, - {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c2a61c0e4811012b0ba9f6cdcb4437865df5d29eab5d6018ba13cee1c3064a0"}, - {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:751758d9dd04d548ec679224cc00e3591f5ebf1ff159ed0d4aba6a0746352452"}, - 
{file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6ba8858933f0c1a979781272a5f65646fca8c18c93c99c6ddb5513ad96fa54b1"}, - {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bfdfbe6a36bc3059fff845d64c42f2644cf875c65f5005db54f90cdfdf1df815"}, - {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa0379c1935c44053c98826bc99ac95f3a5355675a297ac9ce0dfad0ce2d50ca"}, - {file = "rpds_py-0.13.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d5593855b5b2b73dd8413c3fdfa5d95b99d657658f947ba2c4318591e745d083"}, - {file = "rpds_py-0.13.2-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:2a7bef6977043673750a88da064fd513f89505111014b4e00fbdd13329cd4e9a"}, - {file = "rpds_py-0.13.2-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:3ab96754d23372009638a402a1ed12a27711598dd49d8316a22597141962fe66"}, - {file = "rpds_py-0.13.2-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:e06cfea0ece444571d24c18ed465bc93afb8c8d8d74422eb7026662f3d3f779b"}, - {file = "rpds_py-0.13.2-cp38-none-win32.whl", hash = "sha256:5493569f861fb7b05af6d048d00d773c6162415ae521b7010197c98810a14cab"}, - {file = "rpds_py-0.13.2-cp38-none-win_amd64.whl", hash = "sha256:b07501b720cf060c5856f7b5626e75b8e353b5f98b9b354a21eb4bfa47e421b1"}, - {file = "rpds_py-0.13.2-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:881df98f0a8404d32b6de0fd33e91c1b90ed1516a80d4d6dc69d414b8850474c"}, - {file = "rpds_py-0.13.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d79c159adea0f1f4617f54aa156568ac69968f9ef4d1e5fefffc0a180830308e"}, - {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38d4f822ee2f338febcc85aaa2547eb5ba31ba6ff68d10b8ec988929d23bb6b4"}, - {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5d75d6d220d55cdced2f32cc22f599475dbe881229aeddba6c79c2e9df35a2b3"}, - {file = 
"rpds_py-0.13.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d97e9ae94fb96df1ee3cb09ca376c34e8a122f36927230f4c8a97f469994bff"}, - {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:67a429520e97621a763cf9b3ba27574779c4e96e49a27ff8a1aa99ee70beb28a"}, - {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:188435794405c7f0573311747c85a96b63c954a5f2111b1df8018979eca0f2f0"}, - {file = "rpds_py-0.13.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:38f9bf2ad754b4a45b8210a6c732fe876b8a14e14d5992a8c4b7c1ef78740f53"}, - {file = "rpds_py-0.13.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:a6ba2cb7d676e9415b9e9ac7e2aae401dc1b1e666943d1f7bc66223d3d73467b"}, - {file = "rpds_py-0.13.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:eaffbd8814bb1b5dc3ea156a4c5928081ba50419f9175f4fc95269e040eff8f0"}, - {file = "rpds_py-0.13.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5a4c1058cdae6237d97af272b326e5f78ee7ee3bbffa6b24b09db4d828810468"}, - {file = "rpds_py-0.13.2-cp39-none-win32.whl", hash = "sha256:b5267feb19070bef34b8dea27e2b504ebd9d31748e3ecacb3a4101da6fcb255c"}, - {file = "rpds_py-0.13.2-cp39-none-win_amd64.whl", hash = "sha256:ddf23960cb42b69bce13045d5bc66f18c7d53774c66c13f24cf1b9c144ba3141"}, - {file = "rpds_py-0.13.2-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:97163a1ab265a1073a6372eca9f4eeb9f8c6327457a0b22ddfc4a17dcd613e74"}, - {file = "rpds_py-0.13.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:25ea41635d22b2eb6326f58e608550e55d01df51b8a580ea7e75396bafbb28e9"}, - {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76d59d4d451ba77f08cb4cd9268dec07be5bc65f73666302dbb5061989b17198"}, - {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:e7c564c58cf8f248fe859a4f0fe501b050663f3d7fbc342172f259124fb59933"}, - {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61dbc1e01dc0c5875da2f7ae36d6e918dc1b8d2ce04e871793976594aad8a57a"}, - {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fdb82eb60d31b0c033a8e8ee9f3fc7dfbaa042211131c29da29aea8531b4f18f"}, - {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d204957169f0b3511fb95395a9da7d4490fb361763a9f8b32b345a7fe119cb45"}, - {file = "rpds_py-0.13.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c45008ca79bad237cbc03c72bc5205e8c6f66403773929b1b50f7d84ef9e4d07"}, - {file = "rpds_py-0.13.2-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:79bf58c08f0756adba691d480b5a20e4ad23f33e1ae121584cf3a21717c36dfa"}, - {file = "rpds_py-0.13.2-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:e86593bf8637659e6a6ed58854b6c87ec4e9e45ee8a4adfd936831cef55c2d21"}, - {file = "rpds_py-0.13.2-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:d329896c40d9e1e5c7715c98529e4a188a1f2df51212fd65102b32465612b5dc"}, - {file = "rpds_py-0.13.2-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:4a5375c5fff13f209527cd886dc75394f040c7d1ecad0a2cb0627f13ebe78a12"}, - {file = "rpds_py-0.13.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:06d218e4464d31301e943b65b2c6919318ea6f69703a351961e1baaf60347276"}, - {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1f41d32a2ddc5a94df4b829b395916a4b7f103350fa76ba6de625fcb9e773ac"}, - {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6bc568b05e02cd612be53900c88aaa55012e744930ba2eeb56279db4c6676eb3"}, - {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:9d94d78418203904730585efa71002286ac4c8ac0689d0eb61e3c465f9e608ff"}, - {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bed0252c85e21cf73d2d033643c945b460d6a02fc4a7d644e3b2d6f5f2956c64"}, - {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:244e173bb6d8f3b2f0c4d7370a1aa341f35da3e57ffd1798e5b2917b91731fd3"}, - {file = "rpds_py-0.13.2-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7f55cd9cf1564b7b03f238e4c017ca4794c05b01a783e9291065cb2858d86ce4"}, - {file = "rpds_py-0.13.2-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:f03a1b3a4c03e3e0161642ac5367f08479ab29972ea0ffcd4fa18f729cd2be0a"}, - {file = "rpds_py-0.13.2-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:f5f4424cb87a20b016bfdc157ff48757b89d2cc426256961643d443c6c277007"}, - {file = "rpds_py-0.13.2-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:c82bbf7e03748417c3a88c1b0b291288ce3e4887a795a3addaa7a1cfd9e7153e"}, - {file = "rpds_py-0.13.2-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:c0095b8aa3e432e32d372e9a7737e65b58d5ed23b9620fea7cb81f17672f1fa1"}, - {file = "rpds_py-0.13.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:4c2d26aa03d877c9730bf005621c92da263523a1e99247590abbbe252ccb7824"}, - {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96f2975fb14f39c5fe75203f33dd3010fe37d1c4e33177feef1107b5ced750e3"}, - {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4dcc5ee1d0275cb78d443fdebd0241e58772a354a6d518b1d7af1580bbd2c4e8"}, - {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61d42d2b08430854485135504f672c14d4fc644dd243a9c17e7c4e0faf5ed07e"}, - {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:d3a61e928feddc458a55110f42f626a2a20bea942ccedb6fb4cee70b4830ed41"}, - {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7de12b69d95072394998c622cfd7e8cea8f560db5fca6a62a148f902a1029f8b"}, - {file = "rpds_py-0.13.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:87a90f5545fd61f6964e65eebde4dc3fa8660bb7d87adb01d4cf17e0a2b484ad"}, - {file = "rpds_py-0.13.2-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:9c95a1a290f9acf7a8f2ebbdd183e99215d491beea52d61aa2a7a7d2c618ddc6"}, - {file = "rpds_py-0.13.2-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:35f53c76a712e323c779ca39b9a81b13f219a8e3bc15f106ed1e1462d56fcfe9"}, - {file = "rpds_py-0.13.2-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:96fb0899bb2ab353f42e5374c8f0789f54e0a94ef2f02b9ac7149c56622eaf31"}, - {file = "rpds_py-0.13.2.tar.gz", hash = "sha256:f8eae66a1304de7368932b42d801c67969fd090ddb1a7a24f27b435ed4bed68f"}, + {file = "rpds_py-0.16.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:509b617ac787cd1149600e731db9274ebbef094503ca25158e6f23edaba1ca8f"}, + {file = "rpds_py-0.16.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:413b9c17388bbd0d87a329d8e30c1a4c6e44e2bb25457f43725a8e6fe4161e9e"}, + {file = "rpds_py-0.16.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2946b120718eba9af2b4dd103affc1164a87b9e9ebff8c3e4c05d7b7a7e274e2"}, + {file = "rpds_py-0.16.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:35ae5ece284cf36464eb160880018cf6088a9ac5ddc72292a6092b6ef3f4da53"}, + {file = "rpds_py-0.16.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3dc6a7620ba7639a3db6213da61312cb4aa9ac0ca6e00dc1cbbdc21c2aa6eb57"}, + {file = "rpds_py-0.16.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8cb6fe8ecdfffa0e711a75c931fb39f4ba382b4b3ccedeca43f18693864fe850"}, + {file = 
"rpds_py-0.16.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dace7b26a13353e24613417ce2239491b40a6ad44e5776a18eaff7733488b44"}, + {file = "rpds_py-0.16.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1bdbc5fcb04a7309074de6b67fa9bc4b418ab3fc435fec1f2779a0eced688d04"}, + {file = "rpds_py-0.16.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f42e25c016927e2a6b1ce748112c3ab134261fc2ddc867e92d02006103e1b1b7"}, + {file = "rpds_py-0.16.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:eab36eae3f3e8e24b05748ec9acc66286662f5d25c52ad70cadab544e034536b"}, + {file = "rpds_py-0.16.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:0474df4ade9a3b4af96c3d36eb81856cb9462e4c6657d4caecfd840d2a13f3c9"}, + {file = "rpds_py-0.16.2-cp310-none-win32.whl", hash = "sha256:84c5a4d1f9dd7e2d2c44097fb09fffe728629bad31eb56caf97719e55575aa82"}, + {file = "rpds_py-0.16.2-cp310-none-win_amd64.whl", hash = "sha256:2bd82db36cd70b3628c0c57d81d2438e8dd4b7b32a6a9f25f24ab0e657cb6c4e"}, + {file = "rpds_py-0.16.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:adc0c3d6fc6ae35fee3e4917628983f6ce630d513cbaad575b4517d47e81b4bb"}, + {file = "rpds_py-0.16.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ec23fcad480e77ede06cf4127a25fc440f7489922e17fc058f426b5256ee0edb"}, + {file = "rpds_py-0.16.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07aab64e2808c3ebac2a44f67e9dc0543812b715126dfd6fe4264df527556cb6"}, + {file = "rpds_py-0.16.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a4ebb8b20bd09c5ce7884c8f0388801100f5e75e7f733b1b6613c713371feefc"}, + {file = "rpds_py-0.16.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a3d7e2ea25d3517c6d7e5a1cc3702cffa6bd18d9ef8d08d9af6717fc1c700eed"}, + {file = "rpds_py-0.16.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:f28ac0e8e7242d140f99402a903a2c596ab71550272ae9247ad78f9a932b5698"}, + {file = "rpds_py-0.16.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19f00f57fdd38db4bb5ad09f9ead1b535332dbf624200e9029a45f1f35527ebb"}, + {file = "rpds_py-0.16.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3da5a4c56953bdbf6d04447c3410309616c54433146ccdb4a277b9cb499bc10e"}, + {file = "rpds_py-0.16.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ec2e1cf025b2c0f48ec17ff3e642661da7ee332d326f2e6619366ce8e221f018"}, + {file = "rpds_py-0.16.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e0441fb4fdd39a230477b2ca9be90868af64425bfe7b122b57e61e45737a653b"}, + {file = "rpds_py-0.16.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9f0350ef2fba5f34eb0c9000ea328e51b9572b403d2f7f3b19f24085f6f598e8"}, + {file = "rpds_py-0.16.2-cp311-none-win32.whl", hash = "sha256:5a80e2f83391ad0808b4646732af2a7b67550b98f0cae056cb3b40622a83dbb3"}, + {file = "rpds_py-0.16.2-cp311-none-win_amd64.whl", hash = "sha256:e04e56b4ca7a770593633556e8e9e46579d66ec2ada846b401252a2bdcf70a6d"}, + {file = "rpds_py-0.16.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:5e6caa3809e50690bd92fa490f5c38caa86082c8c3315aa438bce43786d5e90d"}, + {file = "rpds_py-0.16.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2e53b9b25cac9065328901713a7e9e3b12e4f57ef4280b370fbbf6fef2052eef"}, + {file = "rpds_py-0.16.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af27423662f32d7501a00c5e7342f7dbd1e4a718aea7a239781357d15d437133"}, + {file = "rpds_py-0.16.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:43d4dd5fb16eb3825742bad8339d454054261ab59fed2fbac84e1d84d5aae7ba"}, + {file = "rpds_py-0.16.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e061de3b745fe611e23cd7318aec2c8b0e4153939c25c9202a5811ca911fd733"}, + {file = 
"rpds_py-0.16.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b811d182ad17ea294f2ec63c0621e7be92a1141e1012383461872cead87468f"}, + {file = "rpds_py-0.16.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5552f328eaef1a75ff129d4d0c437bf44e43f9436d3996e8eab623ea0f5fcf73"}, + {file = "rpds_py-0.16.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dcbe1f8dd179e4d69b70b1f1d9bb6fd1e7e1bdc9c9aad345cdeb332e29d40748"}, + {file = "rpds_py-0.16.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8aad80645a011abae487d356e0ceb359f4938dfb6f7bcc410027ed7ae4f7bb8b"}, + {file = "rpds_py-0.16.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b6f5549d6ed1da9bfe3631ca9483ae906f21410be2445b73443fa9f017601c6f"}, + {file = "rpds_py-0.16.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d452817e0d9c749c431a1121d56a777bd7099b720b3d1c820f1725cb40928f58"}, + {file = "rpds_py-0.16.2-cp312-none-win32.whl", hash = "sha256:888a97002e986eca10d8546e3c8b97da1d47ad8b69726dcfeb3e56348ebb28a3"}, + {file = "rpds_py-0.16.2-cp312-none-win_amd64.whl", hash = "sha256:d8dda2a806dfa4a9b795950c4f5cc56d6d6159f7d68080aedaff3bdc9b5032f5"}, + {file = "rpds_py-0.16.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:071980663c273bf3d388fe5c794c547e6f35ba3335477072c713a3176bf14a60"}, + {file = "rpds_py-0.16.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:726ac36e8a3bb8daef2fd482534cabc5e17334052447008405daca7ca04a3108"}, + {file = "rpds_py-0.16.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e9e557db6a177470316c82f023e5d571811c9a4422b5ea084c85da9aa3c035fc"}, + {file = "rpds_py-0.16.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:90123853fc8b1747f80b0d354be3d122b4365a93e50fc3aacc9fb4c2488845d6"}, + {file = "rpds_py-0.16.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a61f659665a39a4d17d699ab3593d7116d66e1e2e3f03ef3fb8f484e91908808"}, + 
{file = "rpds_py-0.16.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc97f0640e91d7776530f06e6836c546c1c752a52de158720c4224c9e8053cad"}, + {file = "rpds_py-0.16.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44a54e99a2b9693a37ebf245937fd6e9228b4cbd64b9cc961e1f3391ec6c7391"}, + {file = "rpds_py-0.16.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bd4b677d929cf1f6bac07ad76e0f2d5de367e6373351c01a9c0a39f6b21b4a8b"}, + {file = "rpds_py-0.16.2-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:5ef00873303d678aaf8b0627e111fd434925ca01c657dbb2641410f1cdaef261"}, + {file = "rpds_py-0.16.2-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:349cb40897fd529ca15317c22c0eab67f5ac5178b5bd2c6adc86172045210acc"}, + {file = "rpds_py-0.16.2-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:2ddef620e70eaffebed5932ce754d539c0930f676aae6212f8e16cd9743dd365"}, + {file = "rpds_py-0.16.2-cp38-none-win32.whl", hash = "sha256:882ce6e25e585949c3d9f9abd29202367175e0aab3aba0c58c9abbb37d4982ff"}, + {file = "rpds_py-0.16.2-cp38-none-win_amd64.whl", hash = "sha256:f4bd4578e44f26997e9e56c96dedc5f1af43cc9d16c4daa29c771a00b2a26851"}, + {file = "rpds_py-0.16.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:69ac7ea9897ec201ce68b48582f3eb34a3f9924488a5432a93f177bf76a82a7e"}, + {file = "rpds_py-0.16.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a9880b4656efe36ccad41edc66789e191e5ee19a1ea8811e0aed6f69851a82f4"}, + {file = "rpds_py-0.16.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee94cb58c0ba2c62ee108c2b7c9131b2c66a29e82746e8fa3aa1a1effbd3dcf1"}, + {file = "rpds_py-0.16.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:24f7a2eb3866a9e91f4599851e0c8d39878a470044875c49bd528d2b9b88361c"}, + {file = "rpds_py-0.16.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ca57468da2d9a660bcf8961637c85f2fbb2aa64d9bc3f9484e30c3f9f67b1dd7"}, + {file 
= "rpds_py-0.16.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ccd4e400309e1f34a5095bf9249d371f0fd60f8a3a5c4a791cad7b99ce1fd38d"}, + {file = "rpds_py-0.16.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80443fe2f7b3ea3934c5d75fb0e04a5dbb4a8e943e5ff2de0dec059202b70a8b"}, + {file = "rpds_py-0.16.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4d6a9f052e72d493efd92a77f861e45bab2f6be63e37fa8ecf0c6fd1a58fedb0"}, + {file = "rpds_py-0.16.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:35953f4f2b3216421af86fd236b7c0c65935936a94ea83ddbd4904ba60757773"}, + {file = "rpds_py-0.16.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:981d135c7cdaf6cd8eadae1c950de43b976de8f09d8e800feed307140d3d6d00"}, + {file = "rpds_py-0.16.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d0dd7ed2f16df2e129496e7fbe59a34bc2d7fc8db443a606644d069eb69cbd45"}, + {file = "rpds_py-0.16.2-cp39-none-win32.whl", hash = "sha256:703d95c75a72e902544fda08e965885525e297578317989fd15a6ce58414b41d"}, + {file = "rpds_py-0.16.2-cp39-none-win_amd64.whl", hash = "sha256:e93ec1b300acf89730cf27975ef574396bc04edecc358e9bd116fb387a123239"}, + {file = "rpds_py-0.16.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:44627b6ca7308680a70766454db5249105fa6344853af6762eaad4158a2feebe"}, + {file = "rpds_py-0.16.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:3f91df8e6dbb7360e176d1affd5fb0246d2b88d16aa5ebc7db94fd66b68b61da"}, + {file = "rpds_py-0.16.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d904c5693e08bad240f16d79305edba78276be87061c872a4a15e2c301fa2c0"}, + {file = "rpds_py-0.16.2-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:290a81cfbe4673285cdf140ec5cd1658ffbf63ab359f2b352ebe172e7cfa5bf0"}, + {file = "rpds_py-0.16.2-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:b634c5ec0103c5cbebc24ebac4872b045cccb9456fc59efdcf6fe39775365bd2"}, + {file = "rpds_py-0.16.2-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a297a4d08cc67c7466c873c78039d87840fb50d05473db0ec1b7b03d179bf322"}, + {file = "rpds_py-0.16.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2e75e17bd0bb66ee34a707da677e47c14ee51ccef78ed6a263a4cc965a072a1"}, + {file = "rpds_py-0.16.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f1b9d9260e06ea017feb7172976ab261e011c1dc2f8883c7c274f6b2aabfe01a"}, + {file = "rpds_py-0.16.2-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:162d7cd9cd311c1b0ff1c55a024b8f38bd8aad1876b648821da08adc40e95734"}, + {file = "rpds_py-0.16.2-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:9b32f742ce5b57201305f19c2ef7a184b52f6f9ba6871cc042c2a61f0d6b49b8"}, + {file = "rpds_py-0.16.2-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:ac08472f41ea77cd6a5dae36ae7d4ed3951d6602833af87532b556c1b4601d63"}, + {file = "rpds_py-0.16.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:495a14b72bbe217f2695dcd9b5ab14d4f8066a00f5d209ed94f0aca307f85f6e"}, + {file = "rpds_py-0.16.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:8d6b6937ae9eac6d6c0ca3c42774d89fa311f55adff3970fb364b34abde6ed3d"}, + {file = "rpds_py-0.16.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a61226465bda9283686db8f17d02569a98e4b13c637be5a26d44aa1f1e361c2"}, + {file = "rpds_py-0.16.2-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5cf6af100ffb5c195beec11ffaa8cf8523057f123afa2944e6571d54da84cdc9"}, + {file = "rpds_py-0.16.2-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6df15846ee3fb2e6397fe25d7ca6624af9f89587f3f259d177b556fed6bebe2c"}, + {file = "rpds_py-0.16.2-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:1be2f033df1b8be8c3167ba3c29d5dca425592ee31e35eac52050623afba5772"}, + {file = "rpds_py-0.16.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96f957d6ab25a78b9e7fc9749d754b98eac825a112b4e666525ce89afcbd9ed5"}, + {file = "rpds_py-0.16.2-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:088396c7c70e59872f67462fcac3ecbded5233385797021976a09ebd55961dfe"}, + {file = "rpds_py-0.16.2-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4c46ad6356e1561f2a54f08367d1d2e70a0a1bb2db2282d2c1972c1d38eafc3b"}, + {file = "rpds_py-0.16.2-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:47713dc4fce213f5c74ca8a1f6a59b622fc1b90868deb8e8e4d993e421b4b39d"}, + {file = "rpds_py-0.16.2-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:f811771019f063bbd0aa7bb72c8a934bc13ebacb4672d712fc1639cfd314cccc"}, + {file = "rpds_py-0.16.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f19afcfc0dd0dca35694df441e9b0f95bc231b512f51bded3c3d8ca32153ec19"}, + {file = "rpds_py-0.16.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:a4b682c5775d6a3d21e314c10124599976809455ee67020e8e72df1769b87bc3"}, + {file = "rpds_py-0.16.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c647ca87fc0ebe808a41de912e9a1bfef9acb85257e5d63691364ac16b81c1f0"}, + {file = "rpds_py-0.16.2-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:302bd4983bbd47063e452c38be66153760112f6d3635c7eeefc094299fa400a9"}, + {file = "rpds_py-0.16.2-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bf721ede3eb7b829e4a9b8142bd55db0bdc82902720548a703f7e601ee13bdc3"}, + {file = "rpds_py-0.16.2-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:358dafc89ce3894c7f486c615ba914609f38277ef67f566abc4c854d23b997fa"}, + {file = "rpds_py-0.16.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:cad0f59ee3dc35526039f4bc23642d52d5f6616b5f687d846bfc6d0d6d486db0"}, + {file = "rpds_py-0.16.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cffa76b385dfe1e38527662a302b19ffb0e7f5cf7dd5e89186d2c94a22dd9d0c"}, + {file = "rpds_py-0.16.2-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:83640a5d7cd3bff694747d50436b8b541b5b9b9782b0c8c1688931d6ee1a1f2d"}, + {file = "rpds_py-0.16.2-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:ed99b4f7179d2111702020fd7d156e88acd533f5a7d3971353e568b6051d5c97"}, + {file = "rpds_py-0.16.2-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:4022b9dc620e14f30201a8a73898a873c8e910cb642bcd2f3411123bc527f6ac"}, + {file = "rpds_py-0.16.2.tar.gz", hash = "sha256:781ef8bfc091b19960fc0142a23aedadafa826bc32b433fdfe6fd7f964d7ef44"}, ] [[package]] name = "safetensors" version = "0.4.1" description = "" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3811,6 +3996,7 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"] name = "send2trash" version = "1.8.2" description = "Send file to trash natively under Mac OS X, Windows and Linux" +category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" files = [ @@ -3825,13 +4011,14 @@ win32 = ["pywin32"] [[package]] name = "sentry-sdk" -version = "1.38.0" +version = "1.39.1" description = "Python client for Sentry (https://sentry.io)" +category = "main" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.38.0.tar.gz", hash = "sha256:8feab81de6bbf64f53279b085bd3820e3e737403b0a0d9317f73a2c3374ae359"}, - {file = "sentry_sdk-1.38.0-py2.py3-none-any.whl", hash = "sha256:0017fa73b8ae2d4e57fd2522ee3df30453715b29d2692142793ec5d5f90b94a6"}, + {file = "sentry-sdk-1.39.1.tar.gz", hash = "sha256:320a55cdf9da9097a0bead239c35b7e61f53660ef9878861824fd6d9b2eaf3b5"}, + {file = "sentry_sdk-1.39.1-py2.py3-none-any.whl", hash = 
"sha256:81b5b9ffdd1a374e9eb0c053b5d2012155db9cbe76393a8585677b753bd5fdc1"}, ] [package.dependencies] @@ -3872,6 +4059,7 @@ tornado = ["tornado (>=5)"] name = "setproctitle" version = "1.3.3" description = "A Python module to customize the process title" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3970,13 +4158,14 @@ test = ["pytest"] [[package]] name = "setuptools" -version = "69.0.2" +version = "69.0.3" description = "Easily download, build, install, upgrade, and uninstall Python packages" +category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-69.0.2-py3-none-any.whl", hash = "sha256:1e8fdff6797d3865f37397be788a4e3cba233608e9b509382a2777d25ebde7f2"}, - {file = "setuptools-69.0.2.tar.gz", hash = "sha256:735896e78a4742605974de002ac60562d286fa8051a7e2299445e8e8fbb01aa6"}, + {file = "setuptools-69.0.3-py3-none-any.whl", hash = "sha256:385eb4edd9c9d5c17540511303e39a147ce2fc04bc55289c322b9e5904fe2c05"}, + {file = "setuptools-69.0.3.tar.gz", hash = "sha256:be1af57fc409f93647f2e8e4573a142ed38724b8cdd389706a867bb4efcf1e78"}, ] [package.extras] @@ -3988,6 +4177,7 @@ testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jar name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" +category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -3999,6 +4189,7 @@ files = [ name = "smmap" version = "5.0.1" description = "A pure Python implementation of a sliding window memory map manager" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4010,6 +4201,7 @@ files = [ name = "sniffio" version = "1.3.0" description = "Sniff out which async library your code is running under" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4021,6 +4213,7 @@ files = [ name = "snowballstemmer" version = "2.2.0" description = "This package provides 29 stemmers for 28 languages generated from 
Snowball algorithms." +category = "dev" optional = false python-versions = "*" files = [ @@ -4032,6 +4225,7 @@ files = [ name = "soupsieve" version = "2.5" description = "A modern CSS selector implementation for Beautiful Soup." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4043,6 +4237,7 @@ files = [ name = "sphinx" version = "5.2.3" description = "Python documentation generator" +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4078,6 +4273,7 @@ test = ["cython", "html5lib", "pytest (>=4.6)", "typed_ast"] name = "sphinx-autobuild" version = "2021.3.14" description = "Rebuild Sphinx documentation on changes, with live-reload in the browser." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4097,6 +4293,7 @@ test = ["pytest", "pytest-cov"] name = "sphinx-basic-ng" version = "1.0.0b2" description = "A modern skeleton for Sphinx themes." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4114,6 +4311,7 @@ docs = ["furo", "ipython", "myst-parser", "sphinx-copybutton", "sphinx-inline-ta name = "sphinxcontrib-applehelp" version = "1.0.4" description = "sphinxcontrib-applehelp is a Sphinx extension which outputs Apple help books" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4129,6 +4327,7 @@ test = ["pytest"] name = "sphinxcontrib-devhelp" version = "1.0.2" description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document." 
+category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -4144,6 +4343,7 @@ test = ["pytest"] name = "sphinxcontrib-htmlhelp" version = "2.0.1" description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4159,6 +4359,7 @@ test = ["html5lib", "pytest"] name = "sphinxcontrib-jsmath" version = "1.0.1" description = "A sphinx extension which renders display math in HTML via JavaScript" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -4173,6 +4374,7 @@ test = ["flake8", "mypy", "pytest"] name = "sphinxcontrib-napoleon" version = "0.7" description = "Sphinx \"napoleon\" extension." +category = "dev" optional = false python-versions = "*" files = [ @@ -4188,6 +4390,7 @@ six = ">=1.5.2" name = "sphinxcontrib-qthelp" version = "1.0.3" description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document." +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -4203,6 +4406,7 @@ test = ["pytest"] name = "sphinxcontrib-serializinghtml" version = "1.1.5" description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)." 
+category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -4218,6 +4422,7 @@ test = ["pytest"] name = "stack-data" version = "0.6.3" description = "Extract data from python stack frames and tracebacks for informative displays" +category = "dev" optional = false python-versions = "*" files = [ @@ -4237,6 +4442,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] name = "sympy" version = "1.12" description = "Computer algebra system (CAS) in Python" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -4251,6 +4457,7 @@ mpmath = ">=0.19" name = "tabulate" version = "0.9.0" description = "Pretty-print tabular data" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4265,6 +4472,7 @@ widechars = ["wcwidth"] name = "tenacity" version = "8.2.3" description = "Retry code until it succeeds" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4279,6 +4487,7 @@ doc = ["reno", "sphinx", "tornado (>=4.5)"] name = "terminado" version = "0.18.0" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." 
+category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4300,6 +4509,7 @@ typing = ["mypy (>=1.6,<2.0)", "traitlets (>=5.11.1)"] name = "tinycss2" version = "1.2.1" description = "A tiny CSS parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4318,6 +4528,7 @@ test = ["flake8", "isort", "pytest"] name = "tokenizers" version = "0.15.0" description = "" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4433,6 +4644,7 @@ testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] name = "tomli" version = "2.0.1" description = "A lil' TOML parser" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4444,6 +4656,7 @@ files = [ name = "tomlkit" version = "0.12.3" description = "Style preserving TOML library" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4453,31 +4666,32 @@ files = [ [[package]] name = "torch" -version = "2.1.1" +version = "2.1.2" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +category = "main" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.1.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:5ebc43f5355a9b7be813392b3fb0133991f0380f6f0fcc8218d5468dc45d1071"}, - {file = "torch-2.1.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:84fefd63356416c0cd20578637ccdbb82164993400ed17b57c951dd6376dcee8"}, - {file = "torch-2.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:0a7a9da0c324409bcb5a7bdad1b4e94e936d21c2590aaa7ac2f63968da8c62f7"}, - {file = "torch-2.1.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:1e1e5faddd43a8f2c0e0e22beacd1e235a2e447794d807483c94a9e31b54a758"}, - {file = "torch-2.1.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:e76bf3c5c354874f1da465c852a2fb60ee6cbce306e935337885760f080f9baa"}, - {file = "torch-2.1.1-cp311-cp311-manylinux1_x86_64.whl", hash = 
"sha256:98fea993639b0bb432dfceb7b538f07c0f1c33386d63f635219f49254968c80f"}, - {file = "torch-2.1.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:61b51b33c61737c287058b0c3061e6a9d3c363863e4a094f804bc486888a188a"}, - {file = "torch-2.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:1d70920da827e2276bf07f7ec46958621cad18d228c97da8f9c19638474dbd52"}, - {file = "torch-2.1.1-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:a70593806f1d7e6b53657d96810518da0f88ef2608c98a402955765b8c79d52c"}, - {file = "torch-2.1.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e312f7e82e49565f7667b0bbf9559ab0c597063d93044740781c02acd5a87978"}, - {file = "torch-2.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:1e3cbecfa5a7314d828f4a37b0c286714dc9aa2e69beb7a22f7aca76567ed9f4"}, - {file = "torch-2.1.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:9ca0fcbf3d5ba644d6a8572c83a9abbdf5f7ff575bc38529ef6c185a3a71bde9"}, - {file = "torch-2.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:2dc9f312fc1fa0d61a565a0292ad73119d4b74c9f8b5031b55f8b4722abca079"}, - {file = "torch-2.1.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:d56b032176458e2af4709627bbd2c20fe2917eff8cd087a7fe313acccf5ce2f1"}, - {file = "torch-2.1.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:29e3b90a8c281f6660804a939d1f4218604c80162e521e1e6d8c8557325902a0"}, - {file = "torch-2.1.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:bd95cee8511584b67ddc0ba465c3f1edeb5708d833ee02af1206b4486f1d9096"}, - {file = "torch-2.1.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:b31230bd058424e56dba7f899280dbc6ac8b9948e43902e0c84a44666b1ec151"}, - {file = "torch-2.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:403f1095e665e4f35971b43797a920725b8b205723aa68254a4050c6beca29b6"}, - {file = "torch-2.1.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:715b50d8c1de5da5524a68287eb000f73e026e74d5f6b12bc450ef6995fcf5f9"}, - {file = "torch-2.1.1-cp39-none-macosx_11_0_arm64.whl", hash = 
"sha256:db67e8725c76f4c7f4f02e7551bb16e81ba1a1912867bc35d7bb96d2be8c78b4"}, + {file = "torch-2.1.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:3a871edd6c02dae77ad810335c0833391c1a4ce49af21ea8cf0f6a5d2096eea8"}, + {file = "torch-2.1.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076"}, + {file = "torch-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:0e13034fd5fb323cbbc29e56d0637a3791e50dd589616f40c79adfa36a5a35a1"}, + {file = "torch-2.1.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:d9b535cad0df3d13997dbe8bd68ac33e0e3ae5377639c9881948e40794a61403"}, + {file = "torch-2.1.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:f9a55d55af02826ebfbadf4e9b682f0f27766bc33df8236b48d28d705587868f"}, + {file = "torch-2.1.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:a6ebbe517097ef289cc7952783588c72de071d4b15ce0f8b285093f0916b1162"}, + {file = "torch-2.1.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:8f32ce591616a30304f37a7d5ea80b69ca9e1b94bba7f308184bf616fdaea155"}, + {file = "torch-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:e0ee6cf90c8970e05760f898d58f9ac65821c37ffe8b04269ec787aa70962b69"}, + {file = "torch-2.1.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:76d37967c31c99548ad2c4d3f2cf191db48476f2e69b35a0937137116da356a1"}, + {file = "torch-2.1.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e2d83f07b4aac983453ea5bf8f9aa9dacf2278a8d31247f5d9037f37befc60e4"}, + {file = "torch-2.1.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f41fe0c7ecbf903a568c73486139a75cfab287a0f6c17ed0698fdea7a1e8641d"}, + {file = "torch-2.1.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e3225f47d50bb66f756fe9196a768055d1c26b02154eb1f770ce47a2578d3aa7"}, + {file = "torch-2.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:33d59cd03cb60106857f6c26b36457793637512998666ee3ce17311f217afe2b"}, + {file = "torch-2.1.2-cp38-none-macosx_10_9_x86_64.whl", hash = 
"sha256:8e221deccd0def6c2badff6be403e0c53491805ed9915e2c029adbcdb87ab6b5"}, + {file = "torch-2.1.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:05b18594f60a911a0c4f023f38a8bda77131fba5fd741bda626e97dcf5a3dd0a"}, + {file = "torch-2.1.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ca96253b761e9aaf8e06fb30a66ee301aecbf15bb5a303097de1969077620b6"}, + {file = "torch-2.1.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:d93ba70f67b08c2ae5598ee711cbc546a1bc8102cef938904b8c85c2089a51a0"}, + {file = "torch-2.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:255b50bc0608db177e6a3cc118961d77de7e5105f07816585fa6f191f33a9ff3"}, + {file = "torch-2.1.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6984cd5057c0c977b3c9757254e989d3f1124f4ce9d07caa6cb637783c71d42a"}, + {file = "torch-2.1.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:bc195d7927feabc0eb7c110e457c955ed2ab616f3c7c28439dd4188cf589699f"}, ] [package.dependencies] @@ -4508,6 +4722,7 @@ opt-einsum = ["opt-einsum (>=3.3)"] name = "tornado" version = "6.4" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." 
+category = "dev" optional = false python-versions = ">= 3.8" files = [ @@ -4528,6 +4743,7 @@ files = [ name = "tqdm" version = "4.66.1" description = "Fast, Extensible Progress Meter" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4546,13 +4762,14 @@ telegram = ["requests"] [[package]] name = "traitlets" -version = "5.14.0" +version = "5.14.1" description = "Traitlets Python configuration system" +category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "traitlets-5.14.0-py3-none-any.whl", hash = "sha256:f14949d23829023013c47df20b4a76ccd1a85effb786dc060f34de7948361b33"}, - {file = "traitlets-5.14.0.tar.gz", hash = "sha256:fcdaa8ac49c04dfa0ed3ee3384ef6dfdb5d6f3741502be247279407679296772"}, + {file = "traitlets-5.14.1-py3-none-any.whl", hash = "sha256:2e5a030e6eff91737c643231bfcf04a65b0132078dad75e4936700b213652e74"}, + {file = "traitlets-5.14.1.tar.gz", hash = "sha256:8585105b371a04b8316a43d5ce29c098575c2e477850b62b848b964f1444527e"}, ] [package.extras] @@ -4561,18 +4778,19 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0, [[package]] name = "transformers" -version = "4.35.2" +version = "4.36.2" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +category = "main" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.35.2-py3-none-any.whl", hash = "sha256:9dfa76f8692379544ead84d98f537be01cd1070de75c74efb13abcbc938fbe2f"}, - {file = "transformers-4.35.2.tar.gz", hash = "sha256:2d125e197d77b0cdb6c9201df9fa7e2101493272e448b9fba9341c695bee2f52"}, + {file = "transformers-4.36.2-py3-none-any.whl", hash = "sha256:462066c4f74ee52516f12890dcc9ec71d1a5e97998db621668455117a54330f6"}, + {file = "transformers-4.36.2.tar.gz", hash = "sha256:d8068e897e47793281501e547d2bbdfc5b8556409c2cb6c3d9e2ca77d4c0b4ec"}, ] [package.dependencies] filelock = "*" -huggingface-hub = ">=0.16.4,<1.0" +huggingface-hub = ">=0.19.3,<1.0" numpy = ">=1.17" 
packaging = ">=20.0" pyyaml = ">=5.1" @@ -4583,30 +4801,30 @@ tokenizers = ">=0.14,<0.19" tqdm = ">=4.27" [package.extras] -accelerate = ["accelerate (>=0.20.3)"] -agents = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.10,!=1.12.0)"] -all = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +accelerate = ["accelerate (>=0.21.0)"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.10,!=1.12.0)"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] codecarbon = ["codecarbon (==1.2.0)"] -deepspeed = ["accelerate (>=0.20.3)", "deepspeed (>=0.9.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", 
"cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp 
(>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -docs = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", 
"pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers 
(>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", 
"unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] docs-specific = ["hf-doc-builder"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] ftfy = ["ftfy"] -integrations = ["optuna", "ray[tune]", "sigopt"] +integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] modelcreation = ["cookiecutter (==1.7.3)"] natten = ["natten (>=0.14.6)"] onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] optuna = ["optuna"] -quality = ["GitPython (<3.1.19)", "black (>=23.1,<24.0)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (>=0.0.241,<=0.0.259)", "urllib3 (<2.0.0)"] -ray = ["ray[tune]"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"] +ray = ["ray[tune] (>=2.7.0)"] retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] sagemaker = ["sagemaker (>=2.31.0)"] sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] @@ -4614,23 
+4832,24 @@ serving = ["fastapi", "pydantic (<2)", "starlette", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] -tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx"] -tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] timm = ["timm"] tokenizers = ["tokenizers (>=0.14,<0.19)"] -torch = ["accelerate (>=0.20.3)", "torch (>=1.10,!=1.12.0)"] +torch = ["accelerate (>=0.21.0)", "torch (>=1.10,!=1.12.0)"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -torch-vision = ["Pillow (<10.0.0)", "torchvision"] 
-torchhub = ["filelock", "huggingface-hub (>=0.16.4,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"] +torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"] video = ["av (==9.2.0)", "decord (==0.6.0)"] -vision = ["Pillow (<10.0.0)"] +vision = ["Pillow (>=10.0.1,<=15.0)"] [[package]] name = "triton" version = "2.1.0" description = "A language and compiler for custom Deep Learning operations" +category = "main" optional = false python-versions = "*" files = [ @@ -4656,6 +4875,7 @@ tutorials = ["matplotlib", "pandas", "tabulate"] name = "typeguard" version = "4.1.5" description = "Run-time type checker for Python" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -4675,6 +4895,7 @@ test = ["coverage[toml] (>=7)", "mypy (>=1.2.0)", "pytest (>=7)"] name = "typer" version = "0.9.0" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." +category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4694,30 +4915,33 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6. 
[[package]] name = "types-python-dateutil" -version = "2.8.19.14" +version = "2.8.19.20240106" description = "Typing stubs for python-dateutil" +category = "dev" optional = false -python-versions = "*" +python-versions = ">=3.8" files = [ - {file = "types-python-dateutil-2.8.19.14.tar.gz", hash = "sha256:1f4f10ac98bb8b16ade9dbee3518d9ace017821d94b057a425b069f834737f4b"}, - {file = "types_python_dateutil-2.8.19.14-py3-none-any.whl", hash = "sha256:f977b8de27787639986b4e28963263fd0e5158942b3ecef91b9335c130cb1ce9"}, + {file = "types-python-dateutil-2.8.19.20240106.tar.gz", hash = "sha256:1f8db221c3b98e6ca02ea83a58371b22c374f42ae5bbdf186db9c9a76581459f"}, + {file = "types_python_dateutil-2.8.19.20240106-py3-none-any.whl", hash = "sha256:efbbdc54590d0f16152fa103c9879c7d4a00e82078f6e2cf01769042165acaa2"}, ] [[package]] name = "typing-extensions" -version = "4.8.0" +version = "4.9.0" description = "Backported and Experimental Type Hints for Python 3.8+" +category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"}, - {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, + {file = "typing_extensions-4.9.0-py3-none-any.whl", hash = "sha256:af72aea155e91adfc61c3ae9e0e342dbc0cba726d6cba4b6c72c1f34e47291cd"}, + {file = "typing_extensions-4.9.0.tar.gz", hash = "sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783"}, ] [[package]] name = "typing-inspect" version = "0.9.0" description = "Runtime inspection utilities for typing module." 
+category = "dev" optional = false python-versions = "*" files = [ @@ -4731,19 +4955,21 @@ typing-extensions = ">=3.7.4" [[package]] name = "tzdata" -version = "2023.3" +version = "2023.4" description = "Provider of IANA time zone data" +category = "main" optional = false python-versions = ">=2" files = [ - {file = "tzdata-2023.3-py2.py3-none-any.whl", hash = "sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda"}, - {file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"}, + {file = "tzdata-2023.4-py2.py3-none-any.whl", hash = "sha256:aa3ace4329eeacda5b7beb7ea08ece826c28d761cda36e747cfbf97996d39bf3"}, + {file = "tzdata-2023.4.tar.gz", hash = "sha256:dd54c94f294765522c77399649b4fefd95522479a664a0cec87f41bebc6148c9"}, ] [[package]] name = "uri-template" version = "1.3.0" description = "RFC 6570 URI Template Processor" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4758,6 +4984,7 @@ dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake name = "urllib3" version = "2.1.0" description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -4772,13 +4999,14 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "wandb" -version = "0.16.1" +version = "0.16.2" description = "A CLI and library for interacting with the Weights & Biases API." 
+category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "wandb-0.16.1-py3-none-any.whl", hash = "sha256:1d7423f92520984585bae9693bb637ae08d3e0c1d75ad4b34215bc44431f114c"}, - {file = "wandb-0.16.1.tar.gz", hash = "sha256:ffe6e8dd8cc8fcd72010c1246fb3d6d226b37c4f111f3f94308a1c0ae28a2fec"}, + {file = "wandb-0.16.2-py3-none-any.whl", hash = "sha256:6b119cf3c01f35e7276b62d052128e5320621d182c9eb5796a12cf62a9b3134f"}, + {file = "wandb-0.16.2.tar.gz", hash = "sha256:e40cd79ea6272fe4762a80b9f47b172e141daeb3b56eb9d1e192ebd10752e64e"}, ] [package.dependencies] @@ -4803,30 +5031,31 @@ typing-extensions = {version = "*", markers = "python_version < \"3.10\""} async = ["httpx (>=0.23.0)"] aws = ["boto3"] azure = ["azure-identity", "azure-storage-blob"] -core = ["wandb-core (>=0.17.0b2)"] gcp = ["google-cloud-storage"] kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"] -launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "typing-extensions"] -media = ["bokeh", "moviepy", "numpy", "pillow", "plotly", "rdkit-pypi", "soundfile"] +launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "typing-extensions"] +media = ["bokeh", "moviepy", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit-pypi", "soundfile"] models = ["cloudpickle"] perf = ["orjson"] sweeps = ["sweeps (>=0.2.0)"] [[package]] name = "wcwidth" -version = "0.2.12" +version = "0.2.13" description = 
"Measures the displayed width of unicode strings in a terminal" +category = "dev" optional = false python-versions = "*" files = [ - {file = "wcwidth-0.2.12-py2.py3-none-any.whl", hash = "sha256:f26ec43d96c8cbfed76a5075dac87680124fa84e0855195a6184da9c187f133c"}, - {file = "wcwidth-0.2.12.tar.gz", hash = "sha256:f01c104efdf57971bcb756f054dd58ddec5204dd15fa31d6503ea57947d97c02"}, + {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"}, + {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, ] [[package]] name = "webcolors" version = "1.13" description = "A library for working with the color formats defined by HTML and CSS." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4842,6 +5071,7 @@ tests = ["pytest", "pytest-cov"] name = "webencodings" version = "0.5.1" description = "Character encoding aliases for legacy web content" +category = "dev" optional = false python-versions = "*" files = [ @@ -4853,6 +5083,7 @@ files = [ name = "websocket-client" version = "1.7.0" description = "WebSocket client for Python with low level API options" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4869,6 +5100,7 @@ test = ["websockets"] name = "widgetsnbextension" version = "4.0.9" description = "Jupyter interactive widgets for Jupyter Notebook" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4880,6 +5112,7 @@ files = [ name = "xxhash" version = "3.4.1" description = "Python binding for xxHash" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4997,6 +5230,7 @@ files = [ name = "yarl" version = "1.9.4" description = "Yet another URL library" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5100,6 +5334,7 @@ multidict = ">=4.0" name = "zipp" version = "3.17.0" description = "Backport of pathlib-compatible object 
wrapper for zip files" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -5113,5 +5348,5 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" -python-versions = ">=3.8,<4.0" -content-hash = "cf88ba97e4847d4220e2fb639a587d62aa5a98e36fbfc632d7e3914cd08dcebb" + python-versions = ">=3.8,<4.0" +content-hash = "44b4da5ea68927793614a3b2f05fb9ead790d8a4b506c240b17d19b20cbe7cee" diff --git a/pyproject.toml b/pyproject.toml index fe494785d..f43a10572 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,9 +29,10 @@ rich=">=12.6.0" torch=">=1.10,!=2.0,!=2.1.0" # Pin >=2.1.1 due to known MPS errors on 2.1.0 tqdm=">=4.64.1" - transformers=">=4.25.1" + transformers=">=4.34" typing-extensions="*" wandb=">=0.13.5" + better-abc = "^0.0.3" [tool.poetry.group] [tool.poetry.group.dev.dependencies] diff --git a/tests/unit/test_grouped_query_attention.py b/tests/unit/test_grouped_query_attention.py new file mode 100644 index 000000000..885ec39a0 --- /dev/null +++ b/tests/unit/test_grouped_query_attention.py @@ -0,0 +1,82 @@ +import torch + +from transformer_lens.components import Attention, GroupedQueryAttention +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + + +def test_grouped_query_attention_output_is_correct(): + """Verifies that grouped query attention (GPA) block behaves correctly - see https://arxiv.org/abs/2305.13245v2 for details on GPA. + A GPA block with h query heads, n key-value heads, key parameters _K and value parameters _V should have the same output as a regular attention block + with h heads, whose parameters K and V are _K and _V repeated h/n times respectively. 
This test uses torch.repeat_interleave, which is also used by + the GPA block internally, to generate K and V from _K and _V""" + d_model = 512 + d_head = 32 + n_heads = 16 + n_ctx = 128 + n_key_value_heads = 4 + n_layers = 1 + + cfg = HookedTransformerConfig( + d_model=d_model, + d_head=d_head, + n_heads=n_heads, + n_ctx=n_ctx, + n_key_value_heads=n_key_value_heads, + n_layers=n_layers, + act_fn="silu", + ) + + regular_attention = Attention(cfg) + grouped_query_attention = GroupedQueryAttention(cfg) + + W_Q = torch.rand((n_heads, d_model, d_head)) + b_Q = torch.rand((n_heads, d_head)) + _W_K = torch.rand((n_key_value_heads, d_model, d_head)) + W_K = torch.repeat_interleave(_W_K, dim=0, repeats=n_heads // n_key_value_heads) + _b_K = torch.rand((n_key_value_heads, d_head)) + b_K = torch.repeat_interleave(_b_K, dim=0, repeats=n_heads // n_key_value_heads) + _W_V = torch.rand((n_key_value_heads, d_model, d_head)) + W_V = torch.repeat_interleave(_W_V, dim=0, repeats=n_heads // n_key_value_heads) + _b_V = torch.rand((n_key_value_heads, d_head)) + b_V = torch.repeat_interleave(_b_V, dim=0, repeats=n_heads // n_key_value_heads) + W_O = torch.rand((n_heads, d_head, d_model)) + b_O = torch.rand(d_model) + + regular_attention_state_dict = { + "W_Q": W_Q, + "b_Q": b_Q, + "W_O": W_O, + "b_O": b_O, + "W_K": W_K, + "b_K": b_K, + "W_V": W_V, + "b_V": b_V, + "mask": regular_attention.state_dict()["mask"], + "IGNORE": regular_attention.state_dict()["IGNORE"], + } + grouped_query_attemtion_state_dict = { + "W_Q": W_Q, + "b_Q": b_Q, + "W_O": W_O, + "b_O": b_O, + "_W_K": _W_K, + "_b_K": _b_K, + "_W_V": _W_V, + "_b_V": _b_V, + "mask": grouped_query_attention.state_dict()["mask"], + "IGNORE": grouped_query_attention.state_dict()["IGNORE"], + } + + regular_attention.load_state_dict(regular_attention_state_dict) + grouped_query_attention.load_state_dict(grouped_query_attemtion_state_dict) + + query_input = torch.rand((1, 5, d_model)) + key_input = torch.rand((1, 5, d_model)) + value_input 
= torch.rand((1, 5, d_model)) + + regular_attn_output = regular_attention(query_input, key_input, value_input) + grouped_query_attn_output = grouped_query_attention( + query_input, key_input, value_input + ) + + assert torch.equal(regular_attn_output, grouped_query_attn_output) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 88c53ad7e..60379e453 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -313,7 +313,10 @@ def input_to_embed( d_head_in_cache, ) = past_kv_cache[0].past_keys.shape assert cached_batch_size == batch_size - assert num_heads_in_cache == self.cfg.n_heads + if self.cfg.n_key_value_heads is None: + assert num_heads_in_cache == self.cfg.n_heads + else: + assert num_heads_in_cache == self.cfg.n_key_value_heads assert d_head_in_cache == self.cfg.d_head pos_offset = cache_ctx_length if self.cfg.use_hook_tokens: @@ -1651,7 +1654,13 @@ def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]): """ for layer in range(self.cfg.n_layers): # shape [head_index, d_head] - b_V = state_dict[f"blocks.{layer}.attn.b_V"] + if self.cfg.n_key_value_heads is None: + b_V = state_dict[f"blocks.{layer}.attn.b_V"] + else: + b_V = state_dict[f"blocks.{layer}.attn._b_V"] + b_V = torch.repeat_interleave( + b_V, dim=0, repeats=self.cfg.n_heads // self.cfg.n_key_value_heads + ) # [head_index, d_head, d_model] W_O = state_dict[f"blocks.{layer}.attn.W_O"] # [d_model] @@ -1659,7 +1668,12 @@ def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]): folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1]) state_dict[f"blocks.{layer}.attn.b_O"] = folded_b_O - state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros_like(b_V) + if self.cfg.n_key_value_heads is None: + state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros_like(b_V) + else: + state_dict[f"blocks.{layer}.attn._b_V"] = torch.zeros_like( + state_dict[f"blocks.{layer}.attn._b_V"] + ) return state_dict def 
refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]): diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py index d54b785d7..501f6e881 100644 --- a/transformer_lens/HookedTransformerConfig.py +++ b/transformer_lens/HookedTransformerConfig.py @@ -147,6 +147,8 @@ class HookedTransformerConfig: tokenizer_prepends_bos (bool, *optional*): This flag is set by set_tokenizer. It is set to True only when the tokenizer automatically prepends the BOS token if initialized with add_bos_token=True. We need this information to dynamically control bos prepending. + n_key_value_heads (int, *optional*): The number of groups of heads that use the same key and value matrix. + Only for models that use Grouped Query Attention. post_embedding_ln (bool): Whether to apply layer normalization after embedding the tokens. Defaults to False. """ @@ -196,6 +198,7 @@ class HookedTransformerConfig: default_prepend_bos: bool = True dtype: torch.dtype = torch.float32 tokenizer_prepends_bos: Optional[bool] = None + n_key_value_heads: Optional[int] = None post_embedding_ln: bool = False rotary_base: int = 10000 trust_remote_code: bool = False diff --git a/transformer_lens/components.py b/transformer_lens/components.py index b8926cb43..942ec2819 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -5,6 +5,7 @@ :class:`transformer_lens.HookedTransformer`. 
""" import logging +from abc import ABC from typing import Dict, Optional, Tuple, Union import einops @@ -12,6 +13,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from better_abc import abstract_attribute from fancy_einsum import einsum from jaxtyping import Float, Int @@ -381,17 +383,17 @@ def forward( return x * self.w -# Attention -class Attention(nn.Module): +class AbstractAttention(ABC, nn.Module): def __init__( self, cfg: Union[Dict, HookedTransformerConfig], attn_type: str = "global", layer_id: Optional[int] = None, ): - """Attention Block - params have shape [head_index, d_model, d_head] (or [head_index, d_head, d_model] for W_O) and multiply on the right. attn_scores refers to query key dot product immediately before attention softmax + """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks. - Convention: All attention pattern-style matrices have shape [batch, head_index, query_pos, key_pos] + Query and Output projections are defined in this class as they are the same for regular and grouped query attention. + Attributes related to Key and Value projections are abstract as their implementations may differ. 
Args: cfg (Union[Dict, HookedTransformerConfig]): Config @@ -407,16 +409,8 @@ def __init__( self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype ) ) - self.W_K = nn.Parameter( - torch.empty( - self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype - ) - ) - self.W_V = nn.Parameter( - torch.empty( - self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype - ) - ) + self.W_K = abstract_attribute() + self.W_V = abstract_attribute() self.W_O = nn.Parameter( torch.empty( self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model, dtype=cfg.dtype @@ -425,12 +419,8 @@ def __init__( self.b_Q = nn.Parameter( torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype) ) - self.b_K = nn.Parameter( - torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype) - ) - self.b_V = nn.Parameter( - torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype) - ) + self.b_K = abstract_attribute() + self.b_V = abstract_attribute() self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype)) self.attn_type = attn_type @@ -540,37 +530,7 @@ def forward( attention_mask is the attention mask for padded tokens. Defaults to None. 
""" - if self.cfg.use_split_qkv_input or self.cfg.use_attn_in: - qkv_einops_string = "batch pos head_index d_model" - else: - qkv_einops_string = "batch pos d_model" - q = self.hook_q( - einsum( - f"{qkv_einops_string}, head_index d_model d_head \ - -> batch pos head_index d_head", - query_input, - self.W_Q, - ) - + self.b_Q - ) # [batch, pos, head_index, d_head] - k = self.hook_k( - einsum( - f"{qkv_einops_string}, head_index d_model d_head \ - -> batch pos head_index d_head", - key_input, - self.W_K, - ) - + self.b_K - ) # [batch, pos, head_index, d_head] - v = self.hook_v( - einsum( - f"{qkv_einops_string}, head_index d_model d_head \ - -> batch pos head_index d_head", - value_input, - self.W_V, - ) - + self.b_V - ) # [batch, pos, head_index, d_head] + q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input) if past_kv_cache_entry is not None: # Appends the new keys and values to the cached values, and automatically updates the cache @@ -593,15 +553,8 @@ def forward( q = q.to(torch.float32) k = k.to(torch.float32) - attn_scores = ( - einsum( - "batch query_pos head_index d_head, \ - batch key_pos head_index d_head \ - -> batch head_index query_pos key_pos", - q, - k, - ) - / self.attn_scale + attn_scores = self.calculate_attention_scores( + q, k ) # [batch, head_index, query_pos, key_pos] if self.cfg.positional_embedding_type == "alibi": @@ -632,15 +585,7 @@ def forward( pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] pattern = pattern.to(self.cfg.dtype) - z = self.hook_z( - einsum( - "batch key_pos head_index d_head, \ - batch head_index query_pos key_pos -> \ - batch query_pos head_index d_head", - v, - pattern, - ) - ) # [batch, pos, head_index, d_head] + z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] if not self.cfg.use_attn_result: out = ( ( @@ -674,6 +619,92 @@ def forward( ) # [batch, pos, d_model] 
return out + def calculate_qkv_matrices( + self, + query_input: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ], + key_input: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ], + value_input: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ], + ) -> Tuple[ + Float[torch.Tensor, "batch pos head_index d_head"], + Float[torch.Tensor, "batch pos head_index d_head"], + Float[torch.Tensor, "batch pos head_index d_head"], + ]: + if self.cfg.use_split_qkv_input or self.cfg.use_attn_in: + qkv_einops_string = "batch pos head_index d_model" + else: + qkv_einops_string = "batch pos d_model" + + q = self.hook_q( + einsum( + f"{qkv_einops_string}, head_index d_model d_head \ + -> batch pos head_index d_head", + query_input, + self.W_Q, + ) + + self.b_Q + ) # [batch, pos, head_index, d_head] + k = self.hook_k( + einsum( + f"{qkv_einops_string}, head_index d_model d_head \ + -> batch pos head_index d_head", + key_input, + self.W_K, + ) + + self.b_K + ) # [batch, pos, head_index, d_head] + v = self.hook_v( + einsum( + f"{qkv_einops_string}, head_index d_model d_head \ + -> batch pos head_index d_head", + value_input, + self.W_V, + ) + + self.b_V + ) # [batch, pos, head_index, d_head] + return q, k, v + + def calculate_attention_scores( + self, + q: Float[torch.Tensor, "batch query_pos head_index d_head"], + k: Float[torch.Tensor, "batch key_pos head_index d_head"], + ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]: + attn_scores = ( + einsum( + "batch query_pos head_index d_head, \ + batch key_pos head_index d_head \ + -> batch head_index query_pos key_pos", + q, + k, + ) + / self.attn_scale + ) + return attn_scores + + def calculate_z_scores( + self, + v: Float[torch.Tensor, "batch key_pos head_index d_head"], + pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"], + ) -> 
Float[torch.Tensor, "batch query_pos head_index d_head"]: + z = self.hook_z( + einsum( + "batch key_pos head_index d_head, \ + batch head_index query_pos key_pos -> \ + batch query_pos head_index d_head", + v, + pattern, + ) + ) + return z + def apply_causal_mask( self, attn_scores: Float[ @@ -914,6 +945,225 @@ def create_alibi_bias( return alibi_bias +# Attention +class Attention(AbstractAttention): + def __init__( + self, + cfg: Union[Dict, HookedTransformerConfig], + attn_type: str = "global", + layer_id: Optional[int] = None, + ): + """Attention Block - params have shape [head_index, d_model, d_head] (or [head_index, d_head, d_model] for W_O) and multiply on the right. attn_scores refers to query key dot product immediately before attention softmax + + Convention: All attention pattern-style matrices have shape [batch, head_index, query_pos, key_pos] + + Args: + cfg (Union[Dict, HookedTransformerConfig]): Config + attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global". + layer_id (int, optional): The index of the current layer. Used by the Mistal models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. 
+ """ + super().__init__(cfg, attn_type, layer_id) + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.W_K = nn.Parameter( + torch.empty( + self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype + ) + ) + self.W_V = nn.Parameter( + torch.empty( + self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype + ) + ) + self.b_K = nn.Parameter( + torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype) + ) + self.b_V = nn.Parameter( + torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype) + ) + + +class GroupedQueryAttention(AbstractAttention): + def __init__( + self, + cfg: Union[Dict, HookedTransformerConfig], + attn_type: str = "global", + layer_id: Union[int, None] = None, + ): + """Grouped Query Attention Block - see https://arxiv.org/abs/2305.13245v2 for details. + Similar to regular attention, W_Q, W_K, and W_V all have shape [head_index, d_model, d_head] and W_Q has shape [head_index, d_head, d_model]. + However, under the hood the key and value weights _W_K and _W_V are stored with shape [n_key_value_heads, d_model, d_head] and are expanded when the corresponding properties' getter is called. + Similarly, during a forward pass, initially K and V are kept in shapes [batch, pos, n_key_value_heads, d_head] and will only be expanded to shapes [batch, pos, n_heads, d_head] + using torch.repeat_interleave when the attention pattern and z-scores are calculated. + + Args: + cfg (Union[Dict, HookedTransformerConfig]): Config + attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global". + layer_id (int, optional): The index of the current layer. Used by the Mistal models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. 
+ """ + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + assert cfg.n_key_value_heads is not None + super().__init__(cfg, attn_type, layer_id) + self.repeat_kv_heads = cfg.n_heads // cfg.n_key_value_heads + self._W_K = nn.Parameter( + torch.empty( + cfg.n_key_value_heads, + self.cfg.d_model, + self.cfg.d_head, + dtype=cfg.dtype, + ) + ) + self._W_V = nn.Parameter( + torch.empty( + cfg.n_key_value_heads, + self.cfg.d_model, + self.cfg.d_head, + dtype=cfg.dtype, + ) + ) + self._b_K = nn.Parameter( + torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype) + ) + self._b_V = nn.Parameter( + torch.zeros(cfg.n_key_value_heads, self.cfg.d_head, dtype=cfg.dtype) + ) + + @property + def W_K(self): + return torch.repeat_interleave(self._W_K, dim=0, repeats=self.repeat_kv_heads) + + @W_K.setter + def W_K(self, value): + self._W_K = value + + @property + def W_V(self): + return torch.repeat_interleave(self._W_V, dim=0, repeats=self.repeat_kv_heads) + + @W_V.setter + def W_V(self, value): + self._W_V = value + + @property + def b_K(self): + return torch.repeat_interleave(self._b_K, dim=0, repeats=self.repeat_kv_heads) + + @b_K.setter + def b_K(self, value): + self._b_K = value + + @property + def b_V(self): + return torch.repeat_interleave(self._b_V, dim=0, repeats=self.repeat_kv_heads) + + @b_V.setter + def b_V(self, value): + self._b_V = value + + def calculate_qkv_matrices( + self, + query_input: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ], + key_input: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos kv_head_index d_model"], + ], + value_input: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos kv_head_index d_model"], + ], + ) -> Tuple[ + Float[torch.Tensor, "batch pos head_index d_head"], + Float[torch.Tensor, "batch pos kv_head_index d_head"], + Float[torch.Tensor, "batch pos kv_head_index d_head"], + 
]: + """Calculate the Q, K, and V matrices for grouped query attention. + This function uses the unexpanded weights _W_K and _W_V to calculate K and V. + + Args: + query_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos head_index d_model"]]): The input tensor for the query projection. + key_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]): The input tensor for the key projection. Note that is has as many head dimensions as the GPA block has key-value heads. + value_input (Union[Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos kv_head_index d_model"]]): The input tensor for the value projection. Note that is has as many head dimensions as the GPA block has key-value heads. + + Returns: + Tuple[Float[torch.Tensor, "batch pos head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"], Float[torch.Tensor, "batch pos kv_head_index d_head"]]: + A tuple containing the Q, K, and V matrices with the specified shapes. 
+ """ + if self.cfg.use_split_qkv_input or self.cfg.use_attn_in: + qkv_einops_string = "batch pos kv_head_index d_model" + else: + qkv_einops_string = "batch pos d_model" + + q = self.hook_q( + einsum( + f"{qkv_einops_string}, head_index d_model d_head \ + -> batch pos head_index d_head", + query_input, + self.W_Q, + ) + + self.b_Q + ) # [batch, pos, head_index, d_head] + k = self.hook_k( + einsum( + f"{qkv_einops_string}, kv_head_index d_model d_head \ + -> batch pos kv_head_index d_head", + key_input, + self._W_K, + ) + + self._b_K + ) # [batch, pos, head_index, d_head] + v = self.hook_v( + einsum( + f"{qkv_einops_string}, kv_head_index d_model d_head \ + -> batch pos kv_head_index d_head", + value_input, + self._W_V, + ) + + self._b_V + ) # [batch, pos, head_index, d_head] + return q, k, v + + def calculate_attention_scores( + self, + q: Float[torch.Tensor, "batch query_pos head_index d_head"], + k: Float[torch.Tensor, "batch key_pos kv_head_index d_head"], + ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]: + """Calculate attention scores from Q and the unexpanded K matrix. + K will be expaned from [batch, pos, n_key_value_head, d_head] to [batch, pos, n_query_heads, d_head] using torch.repeat_interleave. + + Args: + q (Float[torch.Tensor, "batch query_pos head_index d_head"]): The Q tensor. + k (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The K tensor. + + Returns: + Float[torch.Tensor, "batch head_index query_pos key_pos"]: The attention scores. + """ + k = torch.repeat_interleave(k, dim=2, repeats=self.repeat_kv_heads) + return super().calculate_attention_scores(q, k) + + def calculate_z_scores( + self, + v: Float[torch.Tensor, "batch key_pos kv_head_index d_head"], + pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"], + ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]: + """Calculate z scores from the attention pattern and the unexpanded V matrix. 
+ V will be expaned from [batch, pos, n_key_value_head, d_head] to [batch, pos, n_query_heads, d_head] using torch.repeat_interleave. + + Args: + v (Float[torch.Tensor, "batch query_pos head_index d_head"]): The V tensor. + pattern (Float[torch.Tensor, "batch key_pos kv_head_index d_head"]): The attention pattern. + + Returns: + Float[torch.Tensor, "batch head_index query_pos key_pos"]: The z scores. + """ + v = torch.repeat_interleave(v, dim=2, repeats=self.repeat_kv_heads) + return super().calculate_z_scores(v, pattern) + + # MLP Layers class MLP(nn.Module): def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): @@ -1101,12 +1351,15 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index): f"Invalid normalization_type passed in {self.cfg.normalization_type}" ) + attention = ( + Attention if self.cfg.n_key_value_heads is None else GroupedQueryAttention + ) if not self.cfg.use_local_attn: - self.attn = Attention(cfg, "global", block_index) + self.attn = attention(cfg, "global", block_index) else: assert self.cfg.attn_types is not None attn_type = self.cfg.attn_types[block_index] - self.attn = Attention(cfg, attn_type, block_index) + self.attn = attention(cfg, attn_type, block_index) if not self.cfg.attn_only: if self.cfg.gated_mlp: self.mlp = GatedMLP(cfg) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 10345a069..c75ac3748 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -139,6 +139,8 @@ "stabilityai/stablelm-base-alpha-7b", "stabilityai/stablelm-tuned-alpha-3b", "stabilityai/stablelm-tuned-alpha-7b", + "mistralai/Mistral-7B-v0.1", + "mistralai/Mistral-7B-Instruct-v0.1", "bigscience/bloom-560m", "bigscience/bloom-1b1", "bigscience/bloom-1b7", @@ -524,6 +526,8 @@ "stablelm-tuned-alpha-7b", "stablelm-tuned-7b", ], + "mistralai/Mistral-7B-v0.1": ["mistral-7b"], + "mistralai/Mistral-7B-Instruct-v0.1": 
["mistral-7b-instruct"], "bigscience/bloom-560m": ["bloom-560m"], "bigscience/bloom-1b1": ["bloom-1b1"], "bigscience/bloom-1b7": ["bloom-1b7"], @@ -594,11 +598,13 @@ def convert_hf_model_config(model_name: str, **kwargs): # In case the user passed in an alias official_model_name = get_official_model_name(model_name) # Load HuggingFace model config - if "llama" not in official_model_name.lower(): + if "llama" in official_model_name.lower(): + architecture = "LlamaForCausalLM" + elif "mistral" in official_model_name.lower(): + architecture = "MistralForCausalLM" + else: hf_config = AutoConfig.from_pretrained(official_model_name, **kwargs) architecture = hf_config.architectures[0] - else: - architecture = "LlamaForCausalLM" if official_model_name.startswith( ("llama-7b", "meta-llama/Llama-2-7b") ): # same architecture for LLaMA and Llama-2 @@ -802,6 +808,26 @@ def convert_hf_model_config(model_name: str, **kwargs): "act_fn": "gelu", "attention_dir": "bidirectional", } + elif architecture == "MistralForCausalLM": + cfg_dict = { + "d_model": 4096, + "d_head": 4096 // 32, + "n_heads": 32, + "d_mlp": 14336, + "n_layers": 32, + "n_ctx": 32768, + "d_vocab": 32000, + "act_fn": "silu", + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "window_size": 4096, + "attn_types": ["local"] * 32, + "eps": 1e-05, + "n_key_value_heads": 8, + "gated_mlp": True, + "use_local_attn": True, + "rotary_dim": 4096 // 32, + } elif architecture == "BloomForCausalLM": cfg_dict = { "d_model": hf_config.hidden_size, @@ -817,7 +843,6 @@ def convert_hf_model_config(model_name: str, **kwargs): "post_embedding_ln": True, "positional_embedding_type": "alibi", } - elif architecture == "GPT2LMHeadCustomModel": # santacoder cfg_dict = { @@ -1195,6 +1220,8 @@ def get_pretrained_state_dict( state_dict = convert_llama_weights(hf_model, cfg) elif cfg.original_architecture == "BertForMaskedLM": state_dict = convert_bert_weights(hf_model, cfg) + elif cfg.original_architecture == 
"MistralForCausalLM": + state_dict = convert_mistral_weights(hf_model, cfg) elif cfg.original_architecture == "BloomForCausalLM": state_dict = convert_bloom_weights(hf_model, cfg) elif cfg.original_architecture == "GPT2LMHeadCustomModel": @@ -1599,6 +1626,66 @@ def convert_qwen_weights(qwen, cfg: HookedTransformerConfig): return state_dict +def convert_mistral_weights(mistral, cfg: HookedTransformerConfig): + state_dict = {} + + state_dict["embed.W_E"] = mistral.model.embed_tokens.weight + + # Mistral has no biases anywhere + for l in range(cfg.n_layers): + state_dict[f"blocks.{l}.ln1.w"] = mistral.model.layers[l].input_layernorm.weight + + W_Q = mistral.model.layers[l].self_attn.q_proj.weight + W_K = mistral.model.layers[l].self_attn.k_proj.weight + W_V = mistral.model.layers[l].self_attn.v_proj.weight + W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) + W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads) + W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads) + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q + state_dict[f"blocks.{l}.attn._W_K"] = W_K + state_dict[f"blocks.{l}.attn._W_V"] = W_V + + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( + cfg.n_heads, cfg.d_head, dtype=cfg.dtype + ) + state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( + cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype + ) + state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( + cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype + ) + + W_O = mistral.model.layers[l].self_attn.o_proj.weight + W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) + state_dict[f"blocks.{l}.attn.W_O"] = W_O + + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.ln2.w"] = mistral.model.layers[ + l + ].post_attention_layernorm.weight + + state_dict[f"blocks.{l}.mlp.W_in"] = mistral.model.layers[ + l + ].mlp.up_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_gate"] = mistral.model.layers[ + l + 
].mlp.gate_proj.weight.T + state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.mlp.W_out"] = mistral.model.layers[ + l + ].mlp.down_proj.weight.T + state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + state_dict["ln_final.w"] = mistral.model.norm.weight + + state_dict["unembed.W_U"] = mistral.lm_head.weight.T + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) + + return state_dict + + def convert_opt_weights(opt, cfg: HookedTransformerConfig): state_dict = {} diff --git a/transformer_lens/past_key_value_caching.py b/transformer_lens/past_key_value_caching.py index f80ea2e4d..cff973191 100644 --- a/transformer_lens/past_key_value_caching.py +++ b/transformer_lens/past_key_value_caching.py @@ -27,12 +27,15 @@ def init_cache_entry( device: Union[torch.device, str, None], batch_size: int = 1, ): + n_heads = ( + cfg.n_key_value_heads if cfg.n_key_value_heads is not None else cfg.n_heads + ) return cls( past_keys=torch.empty( - (batch_size, 0, cfg.n_heads, cfg.d_head), device=device, dtype=cfg.dtype + (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype ), past_values=torch.empty( - (batch_size, 0, cfg.n_heads, cfg.d_head), device=device, dtype=cfg.dtype + (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype ), ) From 19b3bc8a23791692d443d347baaebc27568ac3c4 Mon Sep 17 00:00:00 2001 From: Collin Date: Tue, 23 Jan 2024 15:22:57 -0800 Subject: [PATCH 32/73] Implement RMS Layer Norm folding (#489) * add rms norm folding to hookedtransformer * cleanup, add test * formatting * address comments * fix docstring typo * format again * fix convert_llama_weights device issue * add support for folding models w/ gqa * formatting --- tests/acceptance/test_hooked_transformer.py | 73 +++++++ transformer_lens/HookedTransformer.py | 219 ++++++++++++-------- transformer_lens/loading_from_pretrained.py | 4 +- 3 files changed, 211 insertions(+), 85 
deletions(-) diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index e60639e57..f1ee955fa 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -167,6 +167,79 @@ def test_from_pretrained_revision(): raise AssertionError("Should have raised an error") +def check_norm_folding( + model_name, + hf_model=None, + tokenizer=None, + prompt="Hello, world!", + device=None, + dtype=None, +): + """ + Checks that loading a model with Layer/RMS Norm folding enabled does not (significantly) change its outputs. + + Returns the maximum difference between the logits produced by the same model with and without norm folding enabled. + + Also asserts that this difference is within some tolerance, although this is deliberately set to a high value + in order to account for lower precision models. + """ + + # If a device/dtype is not specified, and hf_model is provided, use its device/dtype + # Otherwise, default to cuda (if available)/float32 + if device is None: + if hf_model: + device = hf_model.device + else: + device = "cuda" if torch.cuda.is_available() else "cpu" + if dtype is None: + if hf_model: + dtype = hf_model.dtype + else: + dtype = "float32" + + folded_model = HookedTransformer.from_pretrained( + model_name=model_name, + hf_model=hf_model, + device=device, + tokenizer=tokenizer, + dtype=dtype, + fold_ln=True, + center_writing_weights=False, + center_unembed=False, + ) + tokens = folded_model.to_tokens(prompt) + folded_logits = folded_model(tokens).detach() + del folded_model + torch.cuda.empty_cache() + + unfolded_model = HookedTransformer.from_pretrained( + model_name=model_name, + hf_model=hf_model, + device=device, + tokenizer=tokenizer, + dtype=dtype, + fold_ln=False, + center_writing_weights=False, + center_unembed=False, + ) + unfolded_logits = unfolded_model(tokens).detach() + del unfolded_model + torch.cuda.empty_cache() + + assert torch.allclose( + 
torch.softmax(folded_logits, dim=-1), + torch.softmax(unfolded_logits, dim=-1), + atol=1e-2, + ) + + return torch.max( + torch.abs( + torch.softmax(folded_logits, dim=-1) + - torch.softmax(unfolded_logits, dim=-1) + ) + ) + + def check_similarity_with_hf_model(tl_model, hf_model, prompt="Hello, world!"): """ Check that the TransformerLens model and the HuggingFace model diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 60379e453..5cce47eec 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1423,14 +1423,17 @@ def load_and_process_state_dict( state_dict = self.fill_missing_keys(state_dict) if fold_ln: - if self.cfg.normalization_type not in ["LN", "LNPre"]: - logging.warning( - "You are not using LayerNorm, so the layer norm weights can't be folded! Skipping" + if self.cfg.normalization_type in ["LN", "LNPre"]: + state_dict = self.fold_layer_norm(state_dict) + elif self.cfg.normalization_type in ["RMS", "RMSPre"]: + state_dict = self.fold_layer_norm( + state_dict, fold_biases=False, center_weights=False ) else: - # Note - you can run fold_layer_norm while normalization_type is LN, but this is not advised! It mostly - # goes wrong when you're training the model. - state_dict = self.fold_layer_norm(state_dict) + logging.warning( + "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! 
Skipping" + ) + if center_writing_weights: if self.cfg.normalization_type not in ["LN", "LNPre"]: logging.warning( @@ -1442,19 +1445,23 @@ def load_and_process_state_dict( ) else: state_dict = self.center_writing_weights(state_dict) + if center_unembed: state_dict = self.center_unembed(state_dict) if fold_value_biases: state_dict = self.fold_value_biases(state_dict) if refactor_factored_attn_matrices: state_dict = self.refactor_factored_attn_matrices(state_dict) - self.load_state_dict(state_dict) + + self.load_state_dict(state_dict, strict=False) def fill_missing_keys(self, state_dict): return loading.fill_missing_keys(self, state_dict) - def fold_layer_norm(self, state_dict: Dict[str, torch.Tensor]): - """Fold Layer Norm. + def fold_layer_norm( + self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True + ): + """Fold Layer Norm. Can also be used to fold RMS Norm, when fold_biases and center_weights are set to False. Takes in a state dict from a pretrained model, formatted to be consistent with HookedTransformer but with LayerNorm weights and biases. Folds these into the neighbouring @@ -1462,132 +1469,167 @@ def fold_layer_norm(self, state_dict: Dict[str, torch.Tensor]): Args: state_dict (Dict[str, torch.Tensor]): State dict of pretrained model. + fold_biases (bool): Enables folding of LN biases. Should be disabled when RMS Norm is used. + center_weights (bool): Enables the centering of weights after folding in LN. Should be disabled when RMS Norm is used. """ + + # Models that use Grouped Query Attention (Only Mistral at the time of writing) prefix their K/V weights and + # biases with an underscore in order to distinguish them, but folding the LN into them still works the same, + # so we just add the underscore if GQA is used (i.e. if `cfg.n_key_value_heads is specified`). 
+ gqa = "" if self.cfg.n_key_value_heads is None else "_" + for l in range(self.cfg.n_layers): # Fold ln1 into attention - it's important to fold biases first, since biases depend on # weights but not vice versa The various indexing is just to broadcast ln.b and ln.w # along every axis other than d_model. Each weight matrix right multiplies. To fold in # the bias, we use the W_ matrix to map it to the hidden space of the layer, so we need # to sum along axis -2, which is the residual stream space axis. - state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[f"blocks.{l}.attn.b_Q"] + ( - state_dict[f"blocks.{l}.attn.W_Q"] - * state_dict[f"blocks.{l}.ln1.b"][None, :, None] - ).sum(-2) - state_dict[f"blocks.{l}.attn.b_K"] = state_dict[f"blocks.{l}.attn.b_K"] + ( - state_dict[f"blocks.{l}.attn.W_K"] - * state_dict[f"blocks.{l}.ln1.b"][None, :, None] - ).sum(-2) - state_dict[f"blocks.{l}.attn.b_V"] = state_dict[f"blocks.{l}.attn.b_V"] + ( - state_dict[f"blocks.{l}.attn.W_V"] - * state_dict[f"blocks.{l}.ln1.b"][None, :, None] - ).sum(-2) + if fold_biases: + state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[ + f"blocks.{l}.attn.b_Q" + ] + ( + state_dict[f"blocks.{l}.attn.W_Q"] + * state_dict[f"blocks.{l}.ln1.b"][None, :, None] + ).sum( + -2 + ) + state_dict[f"blocks.{l}.attn.{gqa}b_K"] = state_dict[ + f"blocks.{l}.attn.{gqa}b_K" + ] + ( + state_dict[f"blocks.{l}.attn.{gqa}W_K"] + * state_dict[f"blocks.{l}.ln1.b"][None, :, None] + ).sum( + -2 + ) + state_dict[f"blocks.{l}.attn.{gqa}b_V"] = state_dict[ + f"blocks.{l}.attn.{gqa}b_V" + ] + ( + state_dict[f"blocks.{l}.attn.{gqa}W_V"] + * state_dict[f"blocks.{l}.ln1.b"][None, :, None] + ).sum( + -2 + ) + del state_dict[f"blocks.{l}.ln1.b"] state_dict[f"blocks.{l}.attn.W_Q"] = ( state_dict[f"blocks.{l}.attn.W_Q"] * state_dict[f"blocks.{l}.ln1.w"][None, :, None] ) - state_dict[f"blocks.{l}.attn.W_K"] = ( - state_dict[f"blocks.{l}.attn.W_K"] + state_dict[f"blocks.{l}.attn.{gqa}W_K"] = ( + state_dict[f"blocks.{l}.attn.{gqa}W_K"] * 
state_dict[f"blocks.{l}.ln1.w"][None, :, None] ) - state_dict[f"blocks.{l}.attn.W_V"] = ( - state_dict[f"blocks.{l}.attn.W_V"] + state_dict[f"blocks.{l}.attn.{gqa}W_V"] = ( + state_dict[f"blocks.{l}.attn.{gqa}W_V"] * state_dict[f"blocks.{l}.ln1.w"][None, :, None] ) + del state_dict[f"blocks.{l}.ln1.w"] # Finally, we center the weights reading from the residual stream. The output of the # first part of the LayerNorm is mean 0 and standard deviation 1, so the mean of any # input vector of the matrix doesn't matter and can be set to zero. Equivalently, the # output of LayerNormPre is orthogonal to the vector of all 1s (because dotting with # that gets the sum), so we can remove the component of the matrix parallel to this. - state_dict[f"blocks.{l}.attn.W_Q"] -= einops.reduce( - state_dict[f"blocks.{l}.attn.W_Q"], - "head_index d_model d_head -> head_index 1 d_head", - "mean", - ) - state_dict[f"blocks.{l}.attn.W_K"] -= einops.reduce( - state_dict[f"blocks.{l}.attn.W_K"], - "head_index d_model d_head -> head_index 1 d_head", - "mean", - ) - state_dict[f"blocks.{l}.attn.W_V"] -= einops.reduce( - state_dict[f"blocks.{l}.attn.W_V"], - "head_index d_model d_head -> head_index 1 d_head", - "mean", - ) - - del ( - state_dict[f"blocks.{l}.ln1.w"], - state_dict[f"blocks.{l}.ln1.b"], - ) + if center_weights: + state_dict[f"blocks.{l}.attn.W_Q"] -= einops.reduce( + state_dict[f"blocks.{l}.attn.W_Q"], + "head_index d_model d_head -> head_index 1 d_head", + "mean", + ) + state_dict[f"blocks.{l}.attn.{gqa}W_K"] -= einops.reduce( + state_dict[f"blocks.{l}.attn.{gqa}W_K"], + "head_index d_model d_head -> head_index 1 d_head", + "mean", + ) + state_dict[f"blocks.{l}.attn.{gqa}W_V"] -= einops.reduce( + state_dict[f"blocks.{l}.attn.{gqa}W_V"], + "head_index d_model d_head -> head_index 1 d_head", + "mean", + ) # Fold ln2 into MLP if not self.cfg.attn_only: - state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[ - f"blocks.{l}.mlp.b_in" - ] + ( - state_dict[f"blocks.{l}.mlp.W_in"] - * 
state_dict[f"blocks.{l}.ln2.b"][:, None] - ).sum( - -2 - ) + if fold_biases: + state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[ + f"blocks.{l}.mlp.b_in" + ] + ( + state_dict[f"blocks.{l}.mlp.W_in"] + * state_dict[f"blocks.{l}.ln2.b"][:, None] + ).sum( + -2 + ) + del state_dict[f"blocks.{l}.ln2.b"] + state_dict[f"blocks.{l}.mlp.W_in"] = ( state_dict[f"blocks.{l}.mlp.W_in"] * state_dict[f"blocks.{l}.ln2.w"][:, None] ) - # Center the weights that read in from the LayerNormPre - state_dict[f"blocks.{l}.mlp.W_in"] -= einops.reduce( - state_dict[f"blocks.{l}.mlp.W_in"], - "d_model d_mlp -> 1 d_mlp", - "mean", - ) + if self.cfg.gated_mlp: + state_dict[f"blocks.{l}.mlp.W_gate"] = ( + state_dict[f"blocks.{l}.mlp.W_gate"] + * state_dict[f"blocks.{l}.ln2.w"][:, None] + ) + + del state_dict[f"blocks.{l}.ln2.w"] - del state_dict[f"blocks.{l}.ln2.w"], state_dict[f"blocks.{l}.ln2.b"] + if center_weights: + # Center the weights that read in from the LayerNormPre + state_dict[f"blocks.{l}.mlp.W_in"] -= einops.reduce( + state_dict[f"blocks.{l}.mlp.W_in"], + "d_model d_mlp -> 1 d_mlp", + "mean", + ) if self.cfg.act_fn.startswith("solu"): # Fold ln3 into activation - state_dict[f"blocks.{l}.mlp.b_out"] = state_dict[ - f"blocks.{l}.mlp.b_out" - ] + ( - state_dict[f"blocks.{l}.mlp.W_out"] - * state_dict[f"blocks.{l}.mlp.ln.b"][:, None] - ).sum( - -2 - ) + if fold_biases: + state_dict[f"blocks.{l}.mlp.b_out"] = state_dict[ + f"blocks.{l}.mlp.b_out" + ] + ( + state_dict[f"blocks.{l}.mlp.W_out"] + * state_dict[f"blocks.{l}.mlp.ln.b"][:, None] + ).sum( + -2 + ) + + del state_dict[f"blocks.{l}.mlp.ln.b"] + state_dict[f"blocks.{l}.mlp.W_out"] = ( state_dict[f"blocks.{l}.mlp.W_out"] * state_dict[f"blocks.{l}.mlp.ln.w"][:, None] ) - # Center the weights that read in from the LayerNormPre - state_dict[f"blocks.{l}.mlp.W_out"] -= einops.reduce( - state_dict[f"blocks.{l}.mlp.W_out"], - "d_mlp d_model -> 1 d_model", - "mean", - ) - del ( - state_dict[f"blocks.{l}.mlp.ln.w"], - 
state_dict[f"blocks.{l}.mlp.ln.b"], - ) + if center_weights: + # Center the weights that read in from the LayerNormPre + state_dict[f"blocks.{l}.mlp.W_out"] -= einops.reduce( + state_dict[f"blocks.{l}.mlp.W_out"], + "d_mlp d_model -> 1 d_model", + "mean", + ) + + del state_dict[f"blocks.{l}.mlp.ln.w"] + # Fold ln_final into Unembed - if not self.cfg.final_rms: + if not self.cfg.final_rms and fold_biases: # Dumb bug from my old SoLU training code, some models have RMSNorm instead of LayerNorm # pre unembed. state_dict[f"unembed.b_U"] = state_dict[f"unembed.b_U"] + ( state_dict[f"unembed.W_U"] * state_dict[f"ln_final.b"][:, None] ).sum(dim=-2) del state_dict[f"ln_final.b"] + state_dict[f"unembed.W_U"] = ( state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None] ) + del state_dict[f"ln_final.w"] - # Center the weights that read in from the LayerNormPre - state_dict[f"unembed.W_U"] -= einops.reduce( - state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean" - ) + if center_weights: + # Center the weights that read in from the LayerNormPre + state_dict[f"unembed.W_U"] -= einops.reduce( + state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean" + ) - del state_dict[f"ln_final.w"] return state_dict def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]): @@ -1817,6 +1859,15 @@ def process_weights_( layer.ln2 = LayerNormPre(self.cfg) if self.cfg.act_fn.endswith("_ln"): layer.mlp.ln = LayerNormPre(self.cfg) + elif fold_ln and self.cfg.normalization_type == "RMS": + # We do the same for RMSNorm if used + self.cfg.normalization_type = "RMSPre" + self.ln_final = RMSNormPre(self.cfg) + for layer in self.blocks: + layer.ln1 = RMSNormPre(self.cfg) + layer.ln2 = RMSNormPre(self.cfg) + if self.cfg.act_fn.endswith("_ln"): + layer.mlp.ln = RMSNormPre(self.cfg) self.load_and_process_state_dict( state_dict, diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index c75ac3748..cb81276c3 100644 --- 
a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1019,6 +1019,8 @@ def get_pretrained_model_config( if fold_ln: if cfg_dict["normalization_type"] in ["LN", "LNPre"]: cfg_dict["normalization_type"] = "LNPre" + elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]: + cfg_dict["normalization_type"] = "RMSPre" else: logging.warning("Cannot fold in layer norm, normalization_type is not LN.") @@ -1530,7 +1532,7 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): W_O = llama.model.layers[l].self_attn.o_proj.weight W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) - state_dict[f"blocks.{l}.attn.W_O"] = W_O + state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device) state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros( cfg.d_model, dtype=cfg.dtype, device=cfg.device From ba3fb3bd082083ff75a9b2d9d1c4e16d127c116d Mon Sep 17 00:00:00 2001 From: Collin Date: Sun, 28 Jan 2024 05:44:49 -0800 Subject: [PATCH 33/73] Cap Mistral's context length at 2k (#495) Temporary fix to prevent multiple TB of memory allocated just to attention masks --- transformer_lens/loading_from_pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index cb81276c3..9506beeb6 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -815,7 +815,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "n_heads": 32, "d_mlp": 14336, "n_layers": 32, - "n_ctx": 32768, + "n_ctx": 2048, # Capped due to memory issues "d_vocab": 32000, "act_fn": "silu", "normalization_type": "RMS", From 8a17a76ede0fd5d29f39705366325da0f8c2818f Mon Sep 17 00:00:00 2001 From: cmathw <108584265+cmathw@users.noreply.github.com> Date: Sun, 28 Jan 2024 13:49:17 +0000 Subject: [PATCH 34/73] Add Microsoft Phi models support (#484) --- tests/acceptance/test_hooked_transformer.py | 3 + 
transformer_lens/HookedTransformer.py | 7 ++ transformer_lens/loading_from_pretrained.py | 106 +++++++++++++++++++- 3 files changed, 115 insertions(+), 1 deletion(-) diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index f1ee955fa..0f1cbb438 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -33,6 +33,9 @@ "tiny-stories-33M", "bloom-560m", "santacoder", + "microsoft/phi-1", + "microsoft/phi-1_5", + "microsoft/phi-2", ] text = "Hello world!" """ diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 5cce47eec..5dca72203 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -124,11 +124,18 @@ def __init__( f"{self.cfg.tokenizer_name} tokenizer not loaded. Please load manually." ) else: + # Hugging Face defaults to use_fast to True + use_fast = True + # Phi model's fast tokenizer does not support adding a BOS token, use_fast + # should be False + if "phi" in self.cfg.tokenizer_name.lower(): + use_fast = False self.set_tokenizer( AutoTokenizer.from_pretrained( self.cfg.tokenizer_name, add_bos_token=True, trust_remote_code=self.cfg.trust_remote_code, + use_fast=use_fast, ), default_padding_side=default_padding_side, ) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 9506beeb6..335aa078d 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -153,6 +153,9 @@ "Qwen/Qwen-1_8B-Chat", "Qwen/Qwen-7B-Chat", "Qwen/Qwen-14B-Chat", + "microsoft/phi-1", + "microsoft/phi-1_5", + "microsoft/phi-2", ] """Official model names for models on HuggingFace.""" @@ -540,6 +543,9 @@ "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"], "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"], "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"], + "microsoft/phi-1": ["phi-1"], + "microsoft/phi-1_5": ["phi-1_5"], + 
"microsoft/phi-2": ["phi-2"], } """Model aliases for models on HuggingFace.""" @@ -557,7 +563,13 @@ for name in OFFICIAL_MODEL_NAMES ] -NEED_REMOTE_CODE_MODELS = ("bigcode/santacoder", "Qwen/Qwen-") +NEED_REMOTE_CODE_MODELS = ( + "bigcode/santacoder", + "Qwen/Qwen-", + "microsoft/phi-1", + "microsoft/phi-1_5", + "microsoft/phi-2", +) def make_model_alias_map(): @@ -884,6 +896,29 @@ def convert_hf_model_config(model_name: str, **kwargs): "final_rms": True, "gated_mlp": True, } + elif architecture == "PhiForCausalLM": + # Architecture for microsoft/phi models + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": hf_config.max_position_embeddings, + "eps": hf_config.layer_norm_eps, + "d_vocab": hf_config.vocab_size, + "act_fn": hf_config.hidden_act, + "initializer_range": hf_config.initializer_range, + "normalization_type": "LN", + "positional_embedding_type": "rotary", + "trust_remote_code": True, + "rotary_base": hf_config.rope_theta, + "use_attn_scale": True, + "parallel_attn_mlp": True, + } + partial_rotary_factor = hf_config.partial_rotary_factor + cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"]) + else: raise NotImplementedError(f"{architecture} is not currently supported.") # All of these models use LayerNorm @@ -1230,6 +1265,8 @@ def get_pretrained_state_dict( state_dict = convert_coder_weights(hf_model, cfg) elif cfg.original_architecture == "QWenLMHeadModel": state_dict = convert_qwen_weights(hf_model, cfg) + elif cfg.original_architecture == "PhiForCausalLM": + state_dict = convert_phi_weights(hf_model, cfg) else: raise ValueError( f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." 
@@ -2165,6 +2202,73 @@ def convert_coder_weights(model, cfg: HookedTransformerConfig): return state_dict +def convert_phi_weights(phi, cfg: HookedTransformerConfig): + state_dict = {} + + state_dict["embed.W_E"] = phi.model.embed_tokens.weight + + for l in range(cfg.n_layers): + state_dict[f"blocks.{l}.ln1.w"] = phi.model.layers[l].input_layernorm.weight + state_dict[f"blocks.{l}.ln1.b"] = phi.model.layers[l].input_layernorm.bias + + W_Q = phi.model.layers[l].self_attn.q_proj.weight + W_K = phi.model.layers[l].self_attn.k_proj.weight + W_V = phi.model.layers[l].self_attn.v_proj.weight + W_Q = einops.rearrange( + W_Q, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads + ) + W_K = einops.rearrange( + W_K, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads + ) + W_V = einops.rearrange( + W_V, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q + state_dict[f"blocks.{l}.attn.W_K"] = W_K + state_dict[f"blocks.{l}.attn.W_V"] = W_V + + b_Q = phi.model.layers[l].self_attn.q_proj.bias + b_K = phi.model.layers[l].self_attn.k_proj.bias + b_V = phi.model.layers[l].self_attn.v_proj.bias + b_Q = einops.rearrange( + b_Q, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads + ) + b_K = einops.rearrange( + b_K, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads + ) + b_V = einops.rearrange( + b_V, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn.b_Q"] = b_Q + state_dict[f"blocks.{l}.attn.b_K"] = b_K + state_dict[f"blocks.{l}.attn.b_V"] = b_V + + W_O = phi.model.layers[l].self_attn.dense.weight + W_O = einops.rearrange( + W_O, "d_model (n_head d_head) -> n_head d_head d_model", n_head=cfg.n_heads + ) + + state_dict[f"blocks.{l}.attn.W_O"] = W_O + state_dict[f"blocks.{l}.attn.b_O"] = phi.model.layers[l].self_attn.dense.bias + + # Layer Norm 1 and 2 are tied. 
+ state_dict[f"blocks.{l}.ln2.w"] = state_dict[f"blocks.{l}.ln1.w"] + state_dict[f"blocks.{l}.ln2.b"] = state_dict[f"blocks.{l}.ln1.b"] + + state_dict[f"blocks.{l}.mlp.W_in"] = phi.model.layers[l].mlp.fc1.weight.T + state_dict[f"blocks.{l}.mlp.b_in"] = phi.model.layers[l].mlp.fc1.bias + state_dict[f"blocks.{l}.mlp.W_out"] = phi.model.layers[l].mlp.fc2.weight.T + state_dict[f"blocks.{l}.mlp.b_out"] = phi.model.layers[l].mlp.fc2.bias + + state_dict["ln_final.w"] = phi.model.final_layernorm.weight + state_dict["ln_final.b"] = phi.model.final_layernorm.bias + + state_dict["unembed.W_U"] = phi.lm_head.weight.T + state_dict["unembed.b_U"] = phi.lm_head.bias + + return state_dict + + @dataclasses.dataclass class Config: d_model: int = 768 From 829084a53836c5b8b388aa37a5ffce73b6371712 Mon Sep 17 00:00:00 2001 From: adamkarvonen <85900742+adamkarvonen@users.noreply.github.com> Date: Sun, 28 Jan 2024 07:50:46 -0600 Subject: [PATCH 35/73] Fix a redundant MLP bias assignment (#485) --- transformer_lens/loading_from_pretrained.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 335aa078d..d2bbf6f08 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -2005,13 +2005,6 @@ def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig): f"{layer_key}.attn.c_proj.bias" ] - new_state_dict[f"blocks.{layer}.mlp.b_in"] = old_state_dict[ - f"{layer_key}.mlp.c_fc.bias" - ].T - new_state_dict[f"blocks.{layer}.mlp.b_out"] = old_state_dict[ - f"{layer_key}.mlp.c_proj.bias" - ].T - return new_state_dict From 109fd99900569610bfcace8d04bf85f768288676 Mon Sep 17 00:00:00 2001 From: Andy Arditi Date: Thu, 7 Mar 2024 10:55:14 -0800 Subject: [PATCH 36/73] add qwen1.5 models (#507) * add qwen1.5 models * comments --- demos/Qwen.ipynb | 318 ++++++++++---------- poetry.lock | 257 ++-------------- pyproject.toml | 2 +- 
transformer_lens/loading_from_pretrained.py | 121 ++++++++ 4 files changed, 306 insertions(+), 392 deletions(-) diff --git a/demos/Qwen.ipynb b/demos/Qwen.ipynb index d49b39578..e8ef18f57 100644 --- a/demos/Qwen.ipynb +++ b/demos/Qwen.ipynb @@ -9,66 +9,71 @@ "name": "stdout", "output_type": "stream", "text": [ - "Collecting tiktoken\n", - " Downloading tiktoken-0.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)\n", - "Collecting regex>=2022.1.18 (from tiktoken)\n", - " Downloading regex-2023.12.25-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.9/40.9 kB\u001b[0m \u001b[31m718.2 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m \u001b[36m0:00:01\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: requests>=2.26.0 in /opt/conda/lib/python3.10/site-packages (from tiktoken) (2.31.0)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (2.0.4)\n", - "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (3.4)\n", - "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (1.26.18)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (2023.11.17)\n", - "Downloading tiktoken-0.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", - "\u001b[?25hDownloading regex-2023.12.25-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (773 kB)\n", - "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m774.0/774.0 kB\u001b[0m \u001b[31m11.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", - "\u001b[?25hInstalling collected packages: regex, tiktoken\n", - "Successfully installed regex-2023.12.25 tiktoken-0.5.2\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n", - "Collecting transformers_stream_generator\n", - " Downloading transformers-stream-generator-0.0.4.tar.gz (12 kB)\n", - " Preparing metadata (setup.py) ... \u001b[?25ldone\n", - "\u001b[?25hCollecting transformers>=4.26.1 (from transformers_stream_generator)\n", - " Downloading transformers-4.36.2-py3-none-any.whl.metadata (126 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m126.8/126.8 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (3.13.1)\n", - "Collecting huggingface-hub<1.0,>=0.19.3 (from transformers>=4.26.1->transformers_stream_generator)\n", - " Downloading huggingface_hub-0.20.2-py3-none-any.whl.metadata (12 kB)\n", - "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (1.26.2)\n", - "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (23.1)\n", - "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from 
transformers>=4.26.1->transformers_stream_generator) (6.0.1)\n", - "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (2023.12.25)\n", - "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (2.31.0)\n", - "Collecting tokenizers<0.19,>=0.14 (from transformers>=4.26.1->transformers_stream_generator)\n", - " Downloading tokenizers-0.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n", - "Collecting safetensors>=0.3.1 (from transformers>=4.26.1->transformers_stream_generator)\n", - " Downloading safetensors-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)\n", - "Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (4.65.0)\n", - "Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.19.3->transformers>=4.26.1->transformers_stream_generator) (2023.12.2)\n", - "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.19.3->transformers>=4.26.1->transformers_stream_generator) (4.7.1)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->transformers>=4.26.1->transformers_stream_generator) (2.0.4)\n", - "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->transformers>=4.26.1->transformers_stream_generator) (3.4)\n", - "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->transformers>=4.26.1->transformers_stream_generator) (1.26.18)\n", - "Requirement already satisfied: certifi>=2017.4.17 in 
/opt/conda/lib/python3.10/site-packages (from requests->transformers>=4.26.1->transformers_stream_generator) (2023.11.17)\n", - "Downloading transformers-4.36.2-py3-none-any.whl (8.2 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.2/8.2 MB\u001b[0m \u001b[31m64.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", - "\u001b[?25hDownloading huggingface_hub-0.20.2-py3-none-any.whl (330 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m330.3/330.3 kB\u001b[0m \u001b[31m28.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading safetensors-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m63.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading tokenizers-0.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.8/3.8 MB\u001b[0m \u001b[31m77.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m\n", - "\u001b[?25hBuilding wheels for collected packages: transformers_stream_generator\n", - " Building wheel for transformers_stream_generator (setup.py) ... 
\u001b[?25ldone\n", - "\u001b[?25h Created wheel for transformers_stream_generator: filename=transformers_stream_generator-0.0.4-py3-none-any.whl size=12315 sha256=44d1037124d6e69b847e846035b01ac56e5ebf6d4b115a332c16e85d50c4dc42\n", - " Stored in directory: /root/.cache/pip/wheels/47/1d/3c/92d88493ed40c0d9be60a391eb76c9a56e9f9b7542cb789401\n", - "Successfully built transformers_stream_generator\n", - "Installing collected packages: safetensors, huggingface-hub, tokenizers, transformers, transformers_stream_generator\n", - "Successfully installed huggingface-hub-0.20.2 safetensors-0.4.1 tokenizers-0.15.0 transformers-4.36.2 transformers_stream_generator-0.0.4\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" + "Requirement already satisfied: transformers_stream_generator in /root/TransformerLens/.venv/lib/python3.10/site-packages (0.0.4)\n", + "Requirement already satisfied: plotly in /root/TransformerLens/.venv/lib/python3.10/site-packages (5.18.0)\n", + "Requirement already satisfied: circuitsvis in /root/TransformerLens/.venv/lib/python3.10/site-packages (1.43.2)\n", + "Requirement already satisfied: huggingface_hub in /root/TransformerLens/.venv/lib/python3.10/site-packages (0.20.2)\n", + "Requirement already satisfied: einops in /root/TransformerLens/.venv/lib/python3.10/site-packages (0.7.0)\n", + "Requirement already satisfied: tiktoken in /root/TransformerLens/.venv/lib/python3.10/site-packages (0.5.2)\n", + "Requirement already satisfied: datasets in /root/TransformerLens/.venv/lib/python3.10/site-packages (2.14.4)\n", + "Requirement already satisfied: transformers>=4.26.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from transformers_stream_generator) 
(4.37.2)\n", + "Requirement already satisfied: tenacity>=6.2.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from plotly) (8.2.3)\n", + "Requirement already satisfied: packaging in /root/TransformerLens/.venv/lib/python3.10/site-packages (from plotly) (23.2)\n", + "Requirement already satisfied: importlib-metadata>=5.1.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (7.0.1)\n", + "Requirement already satisfied: numpy>=1.24 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (1.26.3)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) 
(12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (2.18.1)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: torch>=1.10 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (2.1.2)\n", + "Requirement already satisfied: triton==2.1.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from circuitsvis) (2.1.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->circuitsvis) (12.3.101)\n", + "Requirement already satisfied: filelock in /root/TransformerLens/.venv/lib/python3.10/site-packages (from triton==2.1.0->circuitsvis) (3.13.1)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from huggingface_hub) (2023.12.2)\n", + "Requirement already satisfied: requests in /root/TransformerLens/.venv/lib/python3.10/site-packages (from huggingface_hub) (2.31.0)\n", + "Requirement already satisfied: tqdm>=4.42.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from huggingface_hub) (4.66.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from huggingface_hub) (6.0.1)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from huggingface_hub) (4.9.0)\n", + "Requirement already satisfied: regex>=2022.1.18 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from tiktoken) (2023.12.25)\n", + "Requirement already satisfied: pyarrow>=8.0.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from datasets) (14.0.2)\n", + "Requirement already satisfied: dill<0.3.8,>=0.3.0 in 
/root/TransformerLens/.venv/lib/python3.10/site-packages (from datasets) (0.3.7)\n", + "Requirement already satisfied: pandas in /root/TransformerLens/.venv/lib/python3.10/site-packages (from datasets) (2.0.3)\n", + "Requirement already satisfied: xxhash in /root/TransformerLens/.venv/lib/python3.10/site-packages (from datasets) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /root/TransformerLens/.venv/lib/python3.10/site-packages (from datasets) (0.70.15)\n", + "Requirement already satisfied: aiohttp in /root/TransformerLens/.venv/lib/python3.10/site-packages (from datasets) (3.9.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from aiohttp->datasets) (23.2.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from aiohttp->datasets) (6.0.4)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from aiohttp->datasets) (1.9.4)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from aiohttp->datasets) (1.4.1)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from aiohttp->datasets) (4.0.3)\n", + "Requirement already satisfied: zipp>=0.5 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from importlib-metadata>=5.1.0->circuitsvis) (3.17.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from requests->huggingface_hub) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from requests->huggingface_hub) (3.6)\n", + "Requirement 
already satisfied: urllib3<3,>=1.21.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from requests->huggingface_hub) (2.1.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from requests->huggingface_hub) (2023.11.17)\n", + "Requirement already satisfied: sympy in /root/TransformerLens/.venv/lib/python3.10/site-packages (from torch>=1.10->circuitsvis) (1.12)\n", + "Requirement already satisfied: networkx in /root/TransformerLens/.venv/lib/python3.10/site-packages (from torch>=1.10->circuitsvis) (3.1)\n", + "Requirement already satisfied: jinja2 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from torch>=1.10->circuitsvis) (3.1.2)\n", + "Requirement already satisfied: tokenizers<0.19,>=0.14 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (0.15.0)\n", + "Requirement already satisfied: safetensors>=0.4.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from transformers>=4.26.1->transformers_stream_generator) (0.4.1)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from pandas->datasets) (2023.3.post1)\n", + "Requirement already satisfied: tzdata>=2022.1 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from pandas->datasets) (2023.4)\n", + "Requirement already satisfied: six>=1.5 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from jinja2->torch>=1.10->circuitsvis) (2.1.3)\n", + "Requirement already satisfied: mpmath>=0.19 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from 
sympy->torch>=1.10->circuitsvis) (1.3.0)\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n" ] } ], @@ -78,7 +83,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -92,9 +97,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_11422/410710250.py:21: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + "/tmp/ipykernel_13850/410710250.py:21: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", " ipython.magic(\"load_ext autoreload\")\n", - "/tmp/ipykernel_11422/410710250.py:22: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + "/tmp/ipykernel_13850/410710250.py:22: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", " ipython.magic(\"autoreload 2\")\n" ] } @@ -126,7 +131,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -149,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -175,28 +180,53 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def assert_hf_and_tl_model_are_close(\n", + " hf_model,\n", + " tl_model,\n", + " 
tokenizer,\n", + " prompt=\"This is a prompt to test out\",\n", + " atol=1e-3,\n", + "):\n", + " prompt_toks = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", + "\n", + " hf_logits = hf_model(prompt_toks.to(hf_model.device)).logits\n", + " tl_logits = tl_model(prompt_toks).to(hf_logits)\n", + "\n", + " assert torch.allclose(torch.softmax(hf_logits, dim=-1), torch.softmax(tl_logits, dim=-1), atol=atol)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Qwen, first generation" + ] + }, + { + "cell_type": "code", + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Try importing flash-attention for faster inference...\n", - "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary\n", - "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm\n", - "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency https://github.com/Dao-AILab/flash-attention\n" + "Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "943f344bd8c141738f6f3bd9db5c8514", + "model_id": "2cffaf8715b64623b6799822d7cf1cfe", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Loading checkpoint shards: 0%| | 0/8 [00:00=3.8.0" files = [ @@ -35,7 +34,6 @@ testing = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "parameterized", name = "aiohttp" version = "3.9.1" description = "Async http client/server framework (asyncio)" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -132,7 +130,6 @@ speedups = ["Brotli", "aiodns", "brotlicffi"] name = "aiosignal" version = "1.3.1" description = 
"aiosignal: a list of registered asynchronous callbacks" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -147,7 +144,6 @@ frozenlist = ">=1.1.0" name = "alabaster" version = "0.7.13" description = "A configurable sidebar-enabled Sphinx theme" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -159,7 +155,6 @@ files = [ name = "anyio" version = "4.2.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -182,7 +177,6 @@ trio = ["trio (>=0.23)"] name = "appdirs" version = "1.4.4" description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -category = "main" optional = false python-versions = "*" files = [ @@ -194,7 +188,6 @@ files = [ name = "appnope" version = "0.1.3" description = "Disable App Nap on macOS >= 10.9" -category = "dev" optional = false python-versions = "*" files = [ @@ -206,7 +199,6 @@ files = [ name = "argon2-cffi" version = "23.1.0" description = "Argon2 for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -227,7 +219,6 @@ typing = ["mypy"] name = "argon2-cffi-bindings" version = "21.2.0" description = "Low-level CFFI bindings for Argon2" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -265,7 +256,6 @@ tests = ["pytest"] name = "arrow" version = "1.3.0" description = "Better dates & times for Python" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -279,13 +269,12 @@ types-python-dateutil = ">=2.8.10" [package.extras] doc = ["doc8", "sphinx (>=7.0.0)", "sphinx-autobuild", "sphinx-autodoc-typehints", "sphinx_rtd_theme (>=1.3.0)"] -test = ["dateparser (>=1.0.0,<2.0.0)", "pre-commit", "pytest", "pytest-cov", "pytest-mock", "pytz (==2021.1)", "simplejson (>=3.0.0,<4.0.0)"] +test = ["dateparser (==1.*)", "pre-commit", "pytest", "pytest-cov", 
"pytest-mock", "pytz (==2021.1)", "simplejson (==3.*)"] [[package]] name = "asttokens" version = "2.4.1" description = "Annotate AST trees with source code positions" -category = "dev" optional = false python-versions = "*" files = [ @@ -304,7 +293,6 @@ test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"] name = "async-lru" version = "2.0.4" description = "Simple LRU cache for asyncio" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -319,7 +307,6 @@ typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} name = "async-timeout" version = "4.0.3" description = "Timeout context manager for asyncio programs" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -331,7 +318,6 @@ files = [ name = "attrs" version = "23.2.0" description = "Classes Without Boilerplate" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -351,7 +337,6 @@ tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "p name = "babel" version = "2.14.0" description = "Internationalization utilities" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -369,7 +354,6 @@ dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"] name = "backcall" version = "0.2.0" description = "Specifications for callback functions passed in to an API" -category = "dev" optional = false python-versions = "*" files = [ @@ -381,7 +365,6 @@ files = [ name = "beartype" version = "0.14.1" description = "Unbearably fast runtime type checking in pure Python." 
-category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -400,7 +383,6 @@ test-tox-coverage = ["coverage (>=5.5)"] name = "beautifulsoup4" version = "4.12.2" description = "Screen-scraping library" -category = "dev" optional = false python-versions = ">=3.6.0" files = [ @@ -419,7 +401,6 @@ lxml = ["lxml"] name = "better-abc" version = "0.0.3" description = "Python ABC plus abstract attributes" -category = "main" optional = false python-versions = "*" files = [ @@ -431,7 +412,6 @@ files = [ name = "black" version = "23.12.1" description = "The uncompromising code formatter." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -478,7 +458,6 @@ uvloop = ["uvloop (>=0.15.2)"] name = "bleach" version = "6.1.0" description = "An easy safelist-based HTML-sanitizing tool." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -497,7 +476,6 @@ css = ["tinycss2 (>=1.1.0,<1.3)"] name = "certifi" version = "2023.11.17" description = "Python package for providing Mozilla's CA Bundle." -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -509,7 +487,6 @@ files = [ name = "cffi" version = "1.16.0" description = "Foreign Function Interface for Python calling C code." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -574,7 +551,6 @@ pycparser = "*" name = "charset-normalizer" version = "3.3.2" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 
-category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -674,7 +650,6 @@ files = [ name = "circuitsvis" version = "1.43.2" description = "Mechanistic Interpretability Visualizations" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -707,7 +682,6 @@ triton = {version = "2.1.0", markers = "platform_system == \"Linux\" and platfor name = "click" version = "8.1.7" description = "Composable command line interface toolkit" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -722,7 +696,6 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -734,7 +707,6 @@ files = [ name = "comm" version = "0.2.1" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." 
-category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -752,7 +724,6 @@ test = ["pytest"] name = "coverage" version = "7.4.0" description = "Code coverage measurement for Python" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -820,7 +791,6 @@ toml = ["tomli"] name = "datasets" version = "2.14.4" description = "HuggingFace community-driven open-source library of datasets" -category = "main" optional = false python-versions = ">=3.8.0" files = [ @@ -863,7 +833,6 @@ vision = ["Pillow (>=6.2.1)"] name = "debugpy" version = "1.8.0" description = "An implementation of the Debug Adapter Protocol for Python" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -891,7 +860,6 @@ files = [ name = "decorator" version = "5.1.1" description = "Decorators for Humans" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -903,7 +871,6 @@ files = [ name = "defusedxml" version = "0.7.1" description = "XML bomb protection for Python stdlib modules" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -915,7 +882,6 @@ files = [ name = "dill" version = "0.3.7" description = "serialize all of Python" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -930,7 +896,6 @@ graph = ["objgraph (>=1.7.2)"] name = "docker-pycreds" version = "0.4.0" description = "Python bindings for the docker credentials store API" -category = "main" optional = false python-versions = "*" files = [ @@ -945,7 +910,6 @@ six = ">=1.4.0" name = "docutils" version = "0.19" description = "Docutils -- Python Documentation Utilities" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -957,7 +921,6 @@ files = [ name = "einops" version = "0.7.0" description = "A new flavour of deep learning operations" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -969,7 +932,6 @@ files = [ name = "exceptiongroup" 
version = "1.2.0" description = "Backport of PEP 654 (exception groups)" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -984,7 +946,6 @@ test = ["pytest (>=6)"] name = "executing" version = "2.0.1" description = "Get the currently executing AST node of a frame, and other information" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -999,7 +960,6 @@ tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipyth name = "fancy-einsum" version = "0.0.3" description = "Drop-in replacement for torch/numpy einsum, with descriptive variable names in equations" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -1011,7 +971,6 @@ files = [ name = "fastjsonschema" version = "2.19.1" description = "Fastest Python implementation of JSON schema" -category = "dev" optional = false python-versions = "*" files = [ @@ -1026,7 +985,6 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc name = "filelock" version = "3.13.1" description = "A platform independent file lock." 
-category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1043,7 +1001,6 @@ typing = ["typing-extensions (>=4.8)"] name = "fqdn" version = "1.5.1" description = "Validates fully-qualified domain names against RFC 1123, so that they are acceptable to modern bowsers" -category = "dev" optional = false python-versions = ">=2.7, !=3.0, !=3.1, !=3.2, !=3.3, !=3.4, <4" files = [ @@ -1055,7 +1012,6 @@ files = [ name = "frozenlist" version = "1.4.1" description = "A list-like structure which implements collections.abc.MutableSequence" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1142,7 +1098,6 @@ files = [ name = "fsspec" version = "2023.12.2" description = "File-system specification" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1182,7 +1137,6 @@ tqdm = ["tqdm"] name = "furo" version = "2023.3.27" description = "A clean customisable Sphinx documentation theme." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1200,7 +1154,6 @@ sphinx-basic-ng = "*" name = "gitdb" version = "4.0.11" description = "Git Object Database" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1215,7 +1168,6 @@ smmap = ">=3.0.1,<6" name = "gitpython" version = "3.1.40" description = "GitPython is a Python library used to interact with Git repositories" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1233,7 +1185,6 @@ test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre name = "huggingface-hub" version = "0.20.2" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" -category = "main" optional = false python-versions = ">=3.8.0" files = [ @@ -1266,7 +1217,6 @@ typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "t name = "idna" version = "3.6" description = "Internationalized Domain Names in Applications (IDNA)" -category = "main" 
optional = false python-versions = ">=3.5" files = [ @@ -1278,7 +1228,6 @@ files = [ name = "imagesize" version = "1.4.1" description = "Getting image size from png/jpeg/jpeg2000/gif file" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -1290,7 +1239,6 @@ files = [ name = "importlib-metadata" version = "7.0.1" description = "Read metadata from Python packages" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -1310,7 +1258,6 @@ testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs name = "importlib-resources" version = "6.1.1" description = "Read resources from Python packages" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1329,7 +1276,6 @@ testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1341,7 +1287,6 @@ files = [ name = "ipykernel" version = "6.28.0" description = "IPython Kernel for Jupyter" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1355,7 +1300,7 @@ comm = ">=0.1.1" debugpy = ">=1.6.5" ipython = ">=7.23.1" jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" matplotlib-inline = ">=0.1" nest-asyncio = "*" packaging = "*" @@ -1375,7 +1320,6 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio" name = "ipython" version = "8.12.3" description = "IPython: Productive Interactive Computing" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1415,7 +1359,6 @@ test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pa name = "ipywidgets" version = "8.1.1" description = "Jupyter interactive widgets" -category = "dev" optional = false python-versions = ">=3.7" 
files = [ @@ -1437,7 +1380,6 @@ test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"] name = "isoduration" version = "20.11.0" description = "Operations with ISO 8601 durations" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1452,7 +1394,6 @@ arrow = ">=0.15.0" name = "isort" version = "5.8.0" description = "A Python utility / library to sort Python imports." -category = "dev" optional = false python-versions = ">=3.6,<4.0" files = [ @@ -1469,7 +1410,6 @@ requirements-deprecated-finder = ["pip-api", "pipreqs"] name = "jaxtyping" version = "0.2.19" description = "Type annotations and runtime checking for shape and dtype of JAX arrays, and PyTrees." -category = "main" optional = false python-versions = "~=3.8" files = [ @@ -1486,7 +1426,6 @@ typing-extensions = ">=3.7.4.1" name = "jedi" version = "0.19.1" description = "An autocompletion tool for Python that can be used for text editors." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -1506,7 +1445,6 @@ testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1524,7 +1462,6 @@ i18n = ["Babel (>=2.7)"] name = "json5" version = "0.9.14" description = "A Python implementation of the JSON5 data format." 
-category = "dev" optional = false python-versions = "*" files = [ @@ -1539,7 +1476,6 @@ dev = ["hypothesis"] name = "jsonpointer" version = "2.4" description = "Identify specific nodes in a JSON document (RFC 6901)" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ @@ -1551,7 +1487,6 @@ files = [ name = "jsonschema" version = "4.20.0" description = "An implementation of JSON Schema validation for Python" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1583,7 +1518,6 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jsonschema-specifications" version = "2023.12.1" description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1599,7 +1533,6 @@ referencing = ">=0.31.0" name = "jupyter" version = "1.0.0" description = "Jupyter metapackage. Install all the Jupyter components in one go." 
-category = "dev" optional = false python-versions = "*" files = [ @@ -1620,7 +1553,6 @@ qtconsole = "*" name = "jupyter-client" version = "8.6.0" description = "Jupyter protocol implementation and client libraries" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1630,7 +1562,7 @@ files = [ [package.dependencies] importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" python-dateutil = ">=2.8.2" pyzmq = ">=23.0" tornado = ">=6.2" @@ -1644,7 +1576,6 @@ test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pyt name = "jupyter-console" version = "6.6.3" description = "Jupyter terminal console" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1656,7 +1587,7 @@ files = [ ipykernel = ">=6.14" ipython = "*" jupyter-client = ">=7.0.0" -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" prompt-toolkit = ">=3.0.30" pygments = "*" pyzmq = ">=17" @@ -1669,7 +1600,6 @@ test = ["flaky", "pexpect", "pytest"] name = "jupyter-core" version = "5.7.1" description = "Jupyter core package. A base package on which Jupyter projects rely." 
-category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1690,7 +1620,6 @@ test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] name = "jupyter-events" version = "0.9.0" description = "Jupyter Event System library" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1716,7 +1645,6 @@ test = ["click", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>=0.19.0)", "p name = "jupyter-lsp" version = "2.2.1" description = "Multi-Language Server WebSocket proxy for Jupyter Notebook/Lab server" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1732,7 +1660,6 @@ jupyter-server = ">=1.1.2" name = "jupyter-server" version = "2.12.3" description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1745,7 +1672,7 @@ anyio = ">=3.1.0" argon2-cffi = "*" jinja2 = "*" jupyter-client = ">=7.4.4" -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" jupyter-events = ">=0.9.0" jupyter-server-terminals = "*" nbconvert = ">=6.4.4" @@ -1769,7 +1696,6 @@ test = ["flaky", "ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-sc name = "jupyter-server-terminals" version = "0.5.1" description = "A Jupyter Server Extension Providing Terminals." 
-category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1789,7 +1715,6 @@ test = ["jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-jupyter[server] (> name = "jupyterlab" version = "4.0.10" description = "JupyterLab computational environment" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1823,7 +1748,6 @@ test = ["coverage", "pytest (>=7.0)", "pytest-check-links (>=0.7)", "pytest-cons name = "jupyterlab-pygments" version = "0.3.0" description = "Pygments theme using JupyterLab CSS variables" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1835,7 +1759,6 @@ files = [ name = "jupyterlab-server" version = "2.25.2" description = "A set of server components for JupyterLab and JupyterLab like applications." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1862,7 +1785,6 @@ test = ["hatch", "ipykernel", "openapi-core (>=0.18.0,<0.19.0)", "openapi-spec-v name = "jupyterlab-widgets" version = "3.0.9" description = "Jupyter interactive widgets for JupyterLab" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1874,7 +1796,6 @@ files = [ name = "libcst" version = "1.1.0" description = "A concrete syntax tree with AST-like properties for Python 3.5, 3.6, 3.7, 3.8, 3.9, and 3.10 programs." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1923,7 +1844,6 @@ dev = ["Sphinx (>=5.1.1)", "black (==23.9.1)", "build (>=0.10.0)", "coverage (>= name = "livereload" version = "2.6.3" description = "Python LiveReload is an awesome tool for web developers" -category = "dev" optional = false python-versions = "*" files = [ @@ -1939,7 +1859,6 @@ tornado = {version = "*", markers = "python_version > \"2.7\""} name = "markdown-it-py" version = "2.2.0" description = "Python port of markdown-it. Markdown parsing, done right!" 
-category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1964,7 +1883,6 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "markupsafe" version = "2.1.3" description = "Safely add untrusted strings to HTML/XML markup." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2034,7 +1952,6 @@ files = [ name = "matplotlib-inline" version = "0.1.6" description = "Inline Matplotlib backend for Jupyter" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2049,7 +1966,6 @@ traitlets = "*" name = "mdit-py-plugins" version = "0.3.5" description = "Collection of plugins for markdown-it-py" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2069,7 +1985,6 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "mdurl" version = "0.1.2" description = "Markdown URL utilities" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2081,7 +1996,6 @@ files = [ name = "mistune" version = "3.0.2" description = "A sane and fast Markdown parser with useful plugins and renderers" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2093,7 +2007,6 @@ files = [ name = "mpmath" version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" -category = "main" optional = false python-versions = "*" files = [ @@ -2111,7 +2024,6 @@ tests = ["pytest (>=4.6)"] name = "multidict" version = "6.0.4" description = "multidict implementation" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2195,7 +2107,6 @@ files = [ name = "multiprocess" version = "0.70.15" description = "better multiprocessing and multithreading in Python" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2224,7 +2135,6 @@ dill = ">=0.3.7" name = "mypy" version = "1.8.0" description = "Optional static typing for Python" -category = "dev" optional = false 
python-versions = ">=3.8" files = [ @@ -2272,7 +2182,6 @@ reports = ["lxml"] name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2284,7 +2193,6 @@ files = [ name = "myst-parser" version = "1.0.0" description = "An extended [CommonMark](https://spec.commonmark.org/) compliant parser," -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2311,7 +2219,6 @@ testing-docutils = ["pygments", "pytest (>=7,<8)", "pytest-param-files (>=0.3.4, name = "nbclient" version = "0.9.0" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." -category = "dev" optional = false python-versions = ">=3.8.0" files = [ @@ -2321,7 +2228,7 @@ files = [ [package.dependencies] jupyter-client = ">=6.1.12" -jupyter-core = ">=4.12,<5.0.0 || >=5.1.0" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" nbformat = ">=5.1" traitlets = ">=5.4" @@ -2334,7 +2241,6 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= name = "nbconvert" version = "7.14.0" description = "Converting Jupyter Notebooks" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2373,7 +2279,6 @@ webpdf = ["playwright"] name = "nbformat" version = "5.9.2" description = "The Jupyter Notebook format" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2395,7 +2300,6 @@ test = ["pep440", "pre-commit", "pytest", "testpath"] name = "nbsphinx" version = "0.9.3" description = "Jupyter Notebook Tools for Sphinx" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2415,7 +2319,6 @@ traitlets = ">=5" name = "nbval" version = "0.10.0" description = "A py.test plugin to validate Jupyter notebooks" -category = "dev" optional = false python-versions = ">=3.6, <4" files = [ @@ -2434,7 +2337,6 @@ pytest = ">=2.8" name = "nest-asyncio" 
version = "1.5.8" description = "Patch asyncio to allow nested event loops" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -2446,7 +2348,6 @@ files = [ name = "networkx" version = "3.1" description = "Python package for creating and manipulating graphs and networks" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2465,7 +2366,6 @@ test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] name = "notebook" version = "7.0.6" description = "Jupyter Notebook - A web-based notebook environment for interactive computing" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2489,7 +2389,6 @@ test = ["importlib-resources (>=5.0)", "ipykernel", "jupyter-server[test] (>=2.4 name = "notebook-shim" version = "0.2.3" description = "A shim layer for notebook traits and config" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -2507,7 +2406,6 @@ test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync" name = "numpy" version = "1.24.4" description = "Fundamental package for array computing in Python" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2545,7 +2443,6 @@ files = [ name = "numpy" version = "1.26.3" description = "Fundamental package for array computing in Python" -category = "main" optional = false python-versions = ">=3.9" files = [ @@ -2591,7 +2488,6 @@ files = [ name = "nvidia-cublas-cu12" version = "12.1.3.1" description = "CUBLAS native runtime libraries" -category = "main" optional = false python-versions = ">=3" files = [ @@ -2603,7 +2499,6 @@ files = [ name = "nvidia-cuda-cupti-cu12" version = "12.1.105" description = "CUDA profiling tools runtime libs." 
-category = "main" optional = false python-versions = ">=3" files = [ @@ -2615,7 +2510,6 @@ files = [ name = "nvidia-cuda-nvrtc-cu12" version = "12.1.105" description = "NVRTC native runtime libraries" -category = "main" optional = false python-versions = ">=3" files = [ @@ -2627,7 +2521,6 @@ files = [ name = "nvidia-cuda-runtime-cu12" version = "12.1.105" description = "CUDA Runtime native Libraries" -category = "main" optional = false python-versions = ">=3" files = [ @@ -2639,7 +2532,6 @@ files = [ name = "nvidia-cudnn-cu12" version = "8.9.2.26" description = "cuDNN runtime libraries" -category = "main" optional = false python-versions = ">=3" files = [ @@ -2653,7 +2545,6 @@ nvidia-cublas-cu12 = "*" name = "nvidia-cufft-cu12" version = "11.0.2.54" description = "CUFFT native runtime libraries" -category = "main" optional = false python-versions = ">=3" files = [ @@ -2665,7 +2556,6 @@ files = [ name = "nvidia-curand-cu12" version = "10.3.2.106" description = "CURAND native runtime libraries" -category = "main" optional = false python-versions = ">=3" files = [ @@ -2677,7 +2567,6 @@ files = [ name = "nvidia-cusolver-cu12" version = "11.4.5.107" description = "CUDA solver native runtime libraries" -category = "main" optional = false python-versions = ">=3" files = [ @@ -2694,7 +2583,6 @@ nvidia-nvjitlink-cu12 = "*" name = "nvidia-cusparse-cu12" version = "12.1.0.106" description = "CUSPARSE native runtime libraries" -category = "main" optional = false python-versions = ">=3" files = [ @@ -2709,7 +2597,6 @@ nvidia-nvjitlink-cu12 = "*" name = "nvidia-nccl-cu12" version = "2.18.1" description = "NVIDIA Collective Communication Library (NCCL) Runtime" -category = "main" optional = false python-versions = ">=3" files = [ @@ -2720,7 +2607,6 @@ files = [ name = "nvidia-nvjitlink-cu12" version = "12.3.101" description = "Nvidia JIT LTO Library" -category = "main" optional = false python-versions = ">=3" files = [ @@ -2732,7 +2618,6 @@ files = [ name = "nvidia-nvtx-cu12" 
version = "12.1.105" description = "NVIDIA Tools Extension" -category = "main" optional = false python-versions = ">=3" files = [ @@ -2744,7 +2629,6 @@ files = [ name = "overrides" version = "7.4.0" description = "A decorator to automatically detect mismatch when overriding a method." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2756,7 +2640,6 @@ files = [ name = "packaging" version = "23.2" description = "Core utilities for Python packages" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2768,7 +2651,6 @@ files = [ name = "pandas" version = "2.0.3" description = "Powerful data structures for data analysis, time series, and statistics" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -2802,8 +2684,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -2836,7 +2718,6 @@ xml = ["lxml (>=4.6.3)"] name = "pandoc" version = "2.3" description = "Pandoc Documents for Python" -category = "dev" optional = false python-versions = "*" files = [ @@ -2851,7 +2732,6 @@ ply = "*" name = "pandocfilters" version = "1.5.0" description = "Utilities for writing pandoc filters in python" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -2863,7 +2743,6 @@ files = [ name = "parso" version = "0.8.3" description = "A Python Parser" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2879,7 +2758,6 @@ testing = ["docopt", "pytest (<6.0.0)"] name = "pathspec" version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." 
-category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2891,7 +2769,6 @@ files = [ name = "pexpect" version = "4.9.0" description = "Pexpect allows easy control of interactive console applications." -category = "dev" optional = false python-versions = "*" files = [ @@ -2906,7 +2783,6 @@ ptyprocess = ">=0.5" name = "pickleshare" version = "0.7.5" description = "Tiny 'shelve'-like database with concurrency support" -category = "dev" optional = false python-versions = "*" files = [ @@ -2918,7 +2794,6 @@ files = [ name = "pkgutil-resolve-name" version = "1.3.10" description = "Resolve a name to an object." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2930,7 +2805,6 @@ files = [ name = "platformdirs" version = "4.1.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2946,7 +2820,6 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-co name = "plotly" version = "5.18.0" description = "An open-source, interactive data visualization library for Python" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2962,7 +2835,6 @@ tenacity = ">=6.2.0" name = "pluggy" version = "1.3.0" description = "plugin and hook calling mechanisms for python" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -2978,7 +2850,6 @@ testing = ["pytest", "pytest-benchmark"] name = "plumbum" version = "1.8.2" description = "Plumbum: shell combinators library" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2998,7 +2869,6 @@ ssh = ["paramiko"] name = "ply" version = "3.11" description = "Python Lex & Yacc" -category = "dev" optional = false python-versions = "*" files = [ @@ -3010,7 +2880,6 @@ files = [ name = "pockets" version = "0.9.1" description = "A collection of helpful Python tools!" 
-category = "dev" optional = false python-versions = "*" files = [ @@ -3025,7 +2894,6 @@ six = ">=1.5.2" name = "prometheus-client" version = "0.19.0" description = "Python client for the Prometheus monitoring system." -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3040,7 +2908,6 @@ twisted = ["twisted"] name = "prompt-toolkit" version = "3.0.43" description = "Library for building powerful interactive command lines in Python" -category = "dev" optional = false python-versions = ">=3.7.0" files = [ @@ -3055,7 +2922,6 @@ wcwidth = "*" name = "protobuf" version = "4.25.1" description = "" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3076,7 +2942,6 @@ files = [ name = "psutil" version = "5.9.7" description = "Cross-platform lib for process and system monitoring in Python." -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -3105,7 +2970,6 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] name = "ptyprocess" version = "0.7.0" description = "Run a subprocess in a pseudo terminal" -category = "dev" optional = false python-versions = "*" files = [ @@ -3117,7 +2981,6 @@ files = [ name = "pure-eval" version = "0.2.2" description = "Safely evaluate AST nodes without side effects" -category = "dev" optional = false python-versions = "*" files = [ @@ -3132,7 +2995,6 @@ tests = ["pytest"] name = "pyarrow" version = "14.0.2" description = "Python library for Apache Arrow" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -3181,7 +3043,6 @@ numpy = ">=1.16.6" name = "pycln" version = "2.4.0" description = "A formatter for finding and removing unused import statements." 
-category = "dev" optional = false python-versions = ">=3.7.0,<4" files = [ @@ -3200,7 +3061,6 @@ typer = ">=0.4.1" name = "pycparser" version = "2.21" description = "C parser in Python" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3212,7 +3072,6 @@ files = [ name = "pygments" version = "2.17.2" description = "Pygments is a syntax highlighting package written in Python." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3228,7 +3087,6 @@ windows-terminal = ["colorama (>=0.4.6)"] name = "pytest" version = "7.4.4" description = "pytest: simple powerful testing with Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3251,7 +3109,6 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no name = "pytest-cov" version = "4.1.0" description = "Pytest plugin for measuring coverage." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3270,7 +3127,6 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtuale name = "pytest-doctestplus" version = "1.1.0" description = "Pytest plugin with advanced doctest features." 
-category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3290,7 +3146,6 @@ test = ["numpy", "pytest-remotedata (>=0.3.2)", "sphinx"] name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -3305,7 +3160,6 @@ six = ">=1.5" name = "python-json-logger" version = "2.0.7" description = "A python library adding a json log formatter" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3317,7 +3171,6 @@ files = [ name = "pytz" version = "2023.3.post1" description = "World timezone definitions, modern and historical" -category = "main" optional = false python-versions = "*" files = [ @@ -3329,7 +3182,6 @@ files = [ name = "pywin32" version = "306" description = "Python for Window Extensions" -category = "dev" optional = false python-versions = "*" files = [ @@ -3353,7 +3205,6 @@ files = [ name = "pywinpty" version = "2.0.12" description = "Pseudo terminal support for Windows from Python." 
-category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3369,7 +3220,6 @@ files = [ name = "pyyaml" version = "6.0.1" description = "YAML parser and emitter for Python" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3391,6 +3241,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3429,7 +3280,6 @@ files = [ name = "pyzmq" version = "25.1.2" description = "Python bindings for 0MQ" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3535,7 +3385,6 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} name = "qtconsole" version = "5.5.1" description = "Jupyter Qt console" -category = "dev" optional = false python-versions = ">= 3.8" files = [ @@ -3561,7 +3410,6 @@ test = ["flaky", "pytest", "pytest-qt"] name = "qtpy" version = "2.4.1" description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." 
-category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3579,7 +3427,6 @@ test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] name = "referencing" version = "0.32.1" description = "JSON Referencing + Python" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3595,7 +3442,6 @@ rpds-py = ">=0.7.0" name = "regex" version = "2023.12.25" description = "Alternative regular expression module, to replace re." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3698,7 +3544,6 @@ files = [ name = "requests" version = "2.31.0" description = "Python HTTP for Humans." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3720,7 +3565,6 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "rfc3339-validator" version = "0.1.4" description = "A pure python RFC3339 validator" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -3735,7 +3579,6 @@ six = "*" name = "rfc3986-validator" version = "0.1.1" description = "Pure python rfc3986 validator" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -3747,7 +3590,6 @@ files = [ name = "rich" version = "13.7.0" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" -category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -3767,7 +3609,6 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] name = "rpds-py" version = "0.16.2" description = "Python bindings to Rust's persistent data structures (rpds)" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -3876,7 +3717,6 @@ files = [ name = "safetensors" version = "0.4.1" description = "" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3996,7 +3836,6 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"] name = "send2trash" version = 
"1.8.2" description = "Send file to trash natively under Mac OS X, Windows and Linux" -category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" files = [ @@ -4013,7 +3852,6 @@ win32 = ["pywin32"] name = "sentry-sdk" version = "1.39.1" description = "Python client for Sentry (https://sentry.io)" -category = "main" optional = false python-versions = "*" files = [ @@ -4059,7 +3897,6 @@ tornado = ["tornado (>=5)"] name = "setproctitle" version = "1.3.3" description = "A Python module to customize the process title" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4160,7 +3997,6 @@ test = ["pytest"] name = "setuptools" version = "69.0.3" description = "Easily download, build, install, upgrade, and uninstall Python packages" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -4177,7 +4013,6 @@ testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jar name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -4189,7 +4024,6 @@ files = [ name = "smmap" version = "5.0.1" description = "A pure Python implementation of a sliding window memory map manager" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4201,7 +4035,6 @@ files = [ name = "sniffio" version = "1.3.0" description = "Sniff out which async library your code is running under" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4213,7 +4046,6 @@ files = [ name = "snowballstemmer" version = "2.2.0" description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms." -category = "dev" optional = false python-versions = "*" files = [ @@ -4225,7 +4057,6 @@ files = [ name = "soupsieve" version = "2.5" description = "A modern CSS selector implementation for Beautiful Soup." 
-category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4237,7 +4068,6 @@ files = [ name = "sphinx" version = "5.2.3" description = "Python documentation generator" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4273,7 +4103,6 @@ test = ["cython", "html5lib", "pytest (>=4.6)", "typed_ast"] name = "sphinx-autobuild" version = "2021.3.14" description = "Rebuild Sphinx documentation on changes, with live-reload in the browser." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4293,7 +4122,6 @@ test = ["pytest", "pytest-cov"] name = "sphinx-basic-ng" version = "1.0.0b2" description = "A modern skeleton for Sphinx themes." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4311,7 +4139,6 @@ docs = ["furo", "ipython", "myst-parser", "sphinx-copybutton", "sphinx-inline-ta name = "sphinxcontrib-applehelp" version = "1.0.4" description = "sphinxcontrib-applehelp is a Sphinx extension which outputs Apple help books" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4327,7 +4154,6 @@ test = ["pytest"] name = "sphinxcontrib-devhelp" version = "1.0.2" description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document." 
-category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -4343,7 +4169,6 @@ test = ["pytest"] name = "sphinxcontrib-htmlhelp" version = "2.0.1" description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4359,7 +4184,6 @@ test = ["html5lib", "pytest"] name = "sphinxcontrib-jsmath" version = "1.0.1" description = "A sphinx extension which renders display math in HTML via JavaScript" -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -4374,7 +4198,6 @@ test = ["flake8", "mypy", "pytest"] name = "sphinxcontrib-napoleon" version = "0.7" description = "Sphinx \"napoleon\" extension." -category = "dev" optional = false python-versions = "*" files = [ @@ -4390,7 +4213,6 @@ six = ">=1.5.2" name = "sphinxcontrib-qthelp" version = "1.0.3" description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document." -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -4406,7 +4228,6 @@ test = ["pytest"] name = "sphinxcontrib-serializinghtml" version = "1.1.5" description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)." 
-category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -4422,7 +4243,6 @@ test = ["pytest"] name = "stack-data" version = "0.6.3" description = "Extract data from python stack frames and tracebacks for informative displays" -category = "dev" optional = false python-versions = "*" files = [ @@ -4442,7 +4262,6 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] name = "sympy" version = "1.12" description = "Computer algebra system (CAS) in Python" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -4457,7 +4276,6 @@ mpmath = ">=0.19" name = "tabulate" version = "0.9.0" description = "Pretty-print tabular data" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4472,7 +4290,6 @@ widechars = ["wcwidth"] name = "tenacity" version = "8.2.3" description = "Retry code until it succeeds" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4487,7 +4304,6 @@ doc = ["reno", "sphinx", "tornado (>=4.5)"] name = "terminado" version = "0.18.0" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." 
-category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4509,7 +4325,6 @@ typing = ["mypy (>=1.6,<2.0)", "traitlets (>=5.11.1)"] name = "tinycss2" version = "1.2.1" description = "A tiny CSS parser" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4528,7 +4343,6 @@ test = ["flake8", "isort", "pytest"] name = "tokenizers" version = "0.15.0" description = "" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4644,7 +4458,6 @@ testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] name = "tomli" version = "2.0.1" description = "A lil' TOML parser" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4656,7 +4469,6 @@ files = [ name = "tomlkit" version = "0.12.3" description = "Style preserving TOML library" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4668,7 +4480,6 @@ files = [ name = "torch" version = "2.1.2" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" -category = "main" optional = false python-versions = ">=3.8.0" files = [ @@ -4722,7 +4533,6 @@ opt-einsum = ["opt-einsum (>=3.3)"] name = "tornado" version = "6.4" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." 
-category = "dev" optional = false python-versions = ">= 3.8" files = [ @@ -4743,7 +4553,6 @@ files = [ name = "tqdm" version = "4.66.1" description = "Fast, Extensible Progress Meter" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4764,7 +4573,6 @@ telegram = ["requests"] name = "traitlets" version = "5.14.1" description = "Traitlets Python configuration system" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4778,14 +4586,13 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0, [[package]] name = "transformers" -version = "4.36.2" +version = "4.37.2" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" -category = "main" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.36.2-py3-none-any.whl", hash = "sha256:462066c4f74ee52516f12890dcc9ec71d1a5e97998db621668455117a54330f6"}, - {file = "transformers-4.36.2.tar.gz", hash = "sha256:d8068e897e47793281501e547d2bbdfc5b8556409c2cb6c3d9e2ca77d4c0b4ec"}, + {file = "transformers-4.37.2-py3-none-any.whl", hash = "sha256:595a8b12a1fcc4ad0ced49ce206c58e17be68c85d7aee3d7546d04a32c910d2e"}, + {file = "transformers-4.37.2.tar.gz", hash = "sha256:f307082ae5d528b8480611a4879a4a11651012d0e9aaea3f6cf17219ffd95542"}, ] [package.dependencies] @@ -4796,22 +4603,22 @@ packaging = ">=20.0" pyyaml = ">=5.1" regex = "!=2019.12.17" requests = "*" -safetensors = ">=0.3.1" +safetensors = ">=0.4.1" tokenizers = ">=0.14,<0.19" tqdm = ">=4.27" [package.extras] accelerate = ["accelerate (>=0.21.0)"] -agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.10,!=1.12.0)"] -all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp 
(>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.11,!=1.12.0)"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.11,!=1.12.0)", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] codecarbon = ["codecarbon (==1.2.0)"] deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", 
"cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", 
"sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.11,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", 
"sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.11,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] 
+docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.11,!=1.12.0)", "torchaudio", "torchvision"] docs-specific = ["hf-doc-builder"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] @@ -4819,7 +4626,7 @@ ftfy = ["ftfy"] integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] modelcreation = ["cookiecutter (==1.7.3)"] -natten = ["natten (>=0.14.6)"] +natten = ["natten (>=0.14.6,<0.15.0)"] onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] optuna = ["optuna"] @@ -4838,10 +4645,10 @@ tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6, tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] timm = ["timm"] tokenizers = ["tokenizers (>=0.14,<0.19)"] -torch = ["accelerate (>=0.21.0)", "torch (>=1.10,!=1.12.0)"] +torch = ["accelerate (>=0.21.0)", "torch (>=1.11,!=1.12.0)"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", 
"numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"] +torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.11,!=1.12.0)", "tqdm (>=4.27)"] video = ["av (==9.2.0)", "decord (==0.6.0)"] vision = ["Pillow (>=10.0.1,<=15.0)"] @@ -4849,7 +4656,6 @@ vision = ["Pillow (>=10.0.1,<=15.0)"] name = "triton" version = "2.1.0" description = "A language and compiler for custom Deep Learning operations" -category = "main" optional = false python-versions = "*" files = [ @@ -4875,7 +4681,6 @@ tutorials = ["matplotlib", "pandas", "tabulate"] name = "typeguard" version = "4.1.5" description = "Run-time type checker for Python" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -4895,7 +4700,6 @@ test = ["coverage[toml] (>=7)", "mypy (>=1.2.0)", "pytest (>=7)"] name = "typer" version = "0.9.0" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -4917,7 +4721,6 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6. name = "types-python-dateutil" version = "2.8.19.20240106" description = "Typing stubs for python-dateutil" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -4929,7 +4732,6 @@ files = [ name = "typing-extensions" version = "4.9.0" description = "Backported and Experimental Type Hints for Python 3.8+" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -4941,7 +4743,6 @@ files = [ name = "typing-inspect" version = "0.9.0" description = "Runtime inspection utilities for typing module." 
-category = "dev" optional = false python-versions = "*" files = [ @@ -4957,7 +4758,6 @@ typing-extensions = ">=3.7.4" name = "tzdata" version = "2023.4" description = "Provider of IANA time zone data" -category = "main" optional = false python-versions = ">=2" files = [ @@ -4969,7 +4769,6 @@ files = [ name = "uri-template" version = "1.3.0" description = "RFC 6570 URI Template Processor" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4984,7 +4783,6 @@ dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake name = "urllib3" version = "2.1.0" description = "HTTP library with thread-safe connection pooling, file post, and more." -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -5001,7 +4799,6 @@ zstd = ["zstandard (>=0.18.0)"] name = "wandb" version = "0.16.2" description = "A CLI and library for interacting with the Weights & Biases API." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5043,7 +4840,6 @@ sweeps = ["sweeps (>=0.2.0)"] name = "wcwidth" version = "0.2.13" description = "Measures the displayed width of unicode strings in a terminal" -category = "dev" optional = false python-versions = "*" files = [ @@ -5055,7 +4851,6 @@ files = [ name = "webcolors" version = "1.13" description = "A library for working with the color formats defined by HTML and CSS." 
-category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -5071,7 +4866,6 @@ tests = ["pytest", "pytest-cov"] name = "webencodings" version = "0.5.1" description = "Character encoding aliases for legacy web content" -category = "dev" optional = false python-versions = "*" files = [ @@ -5083,7 +4877,6 @@ files = [ name = "websocket-client" version = "1.7.0" description = "WebSocket client for Python with low level API options" -category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -5100,7 +4893,6 @@ test = ["websockets"] name = "widgetsnbextension" version = "4.0.9" description = "Jupyter interactive widgets for Jupyter Notebook" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -5112,7 +4904,6 @@ files = [ name = "xxhash" version = "3.4.1" description = "Python binding for xxHash" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5230,7 +5021,6 @@ files = [ name = "yarl" version = "1.9.4" description = "Yet another URL library" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5334,7 +5124,6 @@ multidict = ">=4.0" name = "zipp" version = "3.17.0" description = "Backport of pathlib-compatible object wrapper for zip files" -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -5348,5 +5137,5 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" - python-versions = ">=3.8,<4.0" -content-hash = "44b4da5ea68927793614a3b2f05fb9ead790d8a4b506c240b17d19b20cbe7cee" +python-versions = ">=3.8,<4.0" +content-hash = "1ef3e46351ab989160cd31387ea5dcc887ba643de8a6f5329ffe0dbbbf16fdc7" diff --git a/pyproject.toml b/pyproject.toml index f43a10572..dd74eddc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ rich=">=12.6.0" torch=">=1.10,!=2.0,!=2.1.0" # Pin >=2.1.1 due to known MPS errors on 2.1.0 tqdm=">=4.64.1" - transformers=">=4.34" + transformers=">=4.37.2" 
typing-extensions="*" wandb=">=0.13.5" better-abc = "^0.0.3" diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index d2bbf6f08..3a0d06b5d 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -153,6 +153,16 @@ "Qwen/Qwen-1_8B-Chat", "Qwen/Qwen-7B-Chat", "Qwen/Qwen-14B-Chat", + "Qwen/Qwen1.5-0.5B", + "Qwen/Qwen1.5-0.5B-Chat", + "Qwen/Qwen1.5-1.8B", + "Qwen/Qwen1.5-1.8B-Chat", + "Qwen/Qwen1.5-4B", + "Qwen/Qwen1.5-4B-Chat", + "Qwen/Qwen1.5-7B", + "Qwen/Qwen1.5-7B-Chat", + "Qwen/Qwen1.5-14B", + "Qwen/Qwen1.5-14B-Chat", "microsoft/phi-1", "microsoft/phi-1_5", "microsoft/phi-2", @@ -543,6 +553,16 @@ "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"], "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"], "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"], + "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"], + "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"], + "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"], + "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"], + "Qwen/Qwen1.5-4B": ["qwen1.5-4b"], + "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"], + "Qwen/Qwen1.5-7B": ["qwen1.5-7b"], + "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"], + "Qwen/Qwen1.5-14B": ["qwen1.5-14b"], + "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"], "microsoft/phi-1": ["phi-1"], "microsoft/phi-1_5": ["phi-1_5"], "microsoft/phi-2": ["phi-2"], @@ -896,6 +916,29 @@ def convert_hf_model_config(model_name: str, **kwargs): "final_rms": True, "gated_mlp": True, } + elif architecture == "Qwen2ForCausalLM": + # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM. 
+ cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big + "eps": hf_config.rms_norm_eps, + "d_vocab": hf_config.vocab_size, + "act_fn": hf_config.hidden_act, + "use_attn_scale": True, + "initializer_range": hf_config.initializer_range, + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "rotary_base": hf_config.rope_theta, + "rotary_adjacent_pairs": False, + "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, + "tokenizer_prepends_bos": True, + "final_rms": True, + "gated_mlp": True, + } elif architecture == "PhiForCausalLM": # Architecture for microsoft/phi models cfg_dict = { @@ -1265,6 +1308,8 @@ def get_pretrained_state_dict( state_dict = convert_coder_weights(hf_model, cfg) elif cfg.original_architecture == "QWenLMHeadModel": state_dict = convert_qwen_weights(hf_model, cfg) + elif cfg.original_architecture == "Qwen2ForCausalLM": + state_dict = convert_qwen2_weights(hf_model, cfg) elif cfg.original_architecture == "PhiForCausalLM": state_dict = convert_phi_weights(hf_model, cfg) else: @@ -1665,6 +1710,82 @@ def convert_qwen_weights(qwen, cfg: HookedTransformerConfig): return state_dict +def convert_qwen2_weights(qwen, cfg: HookedTransformerConfig): + # Note that this method is also applied for Qwen1.5 models, since they + # have architecture type Qwen2ForCausalLM. 
+ + state_dict = {} + + state_dict["embed.W_E"] = qwen.model.embed_tokens.weight + + for l in range(cfg.n_layers): + state_dict[f"blocks.{l}.ln1.w"] = qwen.model.layers[l].input_layernorm.weight + + W_Q = qwen.model.layers[l].self_attn.q_proj.weight + W_K = qwen.model.layers[l].self_attn.k_proj.weight + W_V = qwen.model.layers[l].self_attn.v_proj.weight + W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) + W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads) + W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads) + + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q + state_dict[f"blocks.{l}.attn.W_K"] = W_K + state_dict[f"blocks.{l}.attn.W_V"] = W_V + + b_Q = qwen.model.layers[l].self_attn.q_proj.bias + b_Q = einops.rearrange( + b_Q, + "(n_head d_head) -> n_head d_head", + n_head=cfg.n_heads, + ) + + b_K = qwen.model.layers[l].self_attn.k_proj.bias + b_K = einops.rearrange( + b_K, + "(n_head d_head) -> n_head d_head", + n_head=cfg.n_heads, + ) + + b_V = qwen.model.layers[l].self_attn.v_proj.bias + b_V = einops.rearrange( + b_V, + "(n_head d_head) -> n_head d_head", + n_head=cfg.n_heads, + ) + + state_dict[f"blocks.{l}.attn.b_Q"] = b_Q + state_dict[f"blocks.{l}.attn.b_K"] = b_K + state_dict[f"blocks.{l}.attn.b_V"] = b_V + + W_O = qwen.model.layers[l].self_attn.o_proj.weight + W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) + state_dict[f"blocks.{l}.attn.W_O"] = W_O + + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.ln2.w"] = qwen.model.layers[ + l + ].post_attention_layernorm.weight + + state_dict[f"blocks.{l}.mlp.W_in"] = qwen.model.layers[l].mlp.up_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_gate"] = qwen.model.layers[ + l + ].mlp.gate_proj.weight.T + state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.mlp.W_out"] = qwen.model.layers[ + l + ].mlp.down_proj.weight.T + state_dict[f"blocks.{l}.mlp.b_out"] = 
torch.zeros(cfg.d_model, dtype=cfg.dtype) + + state_dict["ln_final.w"] = qwen.model.norm.weight + + state_dict["unembed.W_U"] = qwen.lm_head.weight.T + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) + + return state_dict + + def convert_mistral_weights(mistral, cfg: HookedTransformerConfig): state_dict = {} From 6673d8811ab7980ca0509674930b2a9cb43a0f0e Mon Sep 17 00:00:00 2001 From: cmathw <108584265+cmathw@users.noreply.github.com> Date: Thu, 14 Mar 2024 22:49:17 +0000 Subject: [PATCH 37/73] Support Gemma Models (#511) * add cfg dict and weights converter * norm embedding and add ones dtype * formatting * add gemma to acceptance tests * rename post_embedding_norm to post_embedding_scale * add batch dim to to_str_tok if tokenizer is gemma * add instruct-tuned models * hardcode gemma configs * add convert_gemma_weights to ignored functions in docs conf file * add architecture name * scale embedding weights when converting weights from HF to TL * remove redundant config comment * remove redundant arg * remove post_embedding_scale from config definition --- docs/source/conf.py | 1 + tests/acceptance/test_hooked_transformer.py | 2 + transformer_lens/HookedTransformer.py | 4 + transformer_lens/loading_from_pretrained.py | 126 ++++++++++++++++++++ 4 files changed, 133 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index 9308ea2b2..96d6b42e6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -86,6 +86,7 @@ "convert_neox_weights", "convert_neel_model_config", "convert_opt_weights", + "convert_gemma_weights", "fill_missing_keys", "get_basic_config", "get_official_model_name", diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index 0f1cbb438..7f3878161 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -36,6 +36,8 @@ "microsoft/phi-1", "microsoft/phi-1_5", "microsoft/phi-2", + "google/gemma-2b", + 
"google/gemma-7b", ] text = "Hello world!" """ diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 5dca72203..eb6fd6ba9 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -871,6 +871,9 @@ def to_str_tokens( tokens = self.to_tokens( input, prepend_bos=prepend_bos, padding_side=padding_side )[0] + # Gemma tokenizer expects a batch dimension + if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1: + tokens = tokens.unsqueeze(1) elif isinstance(input, torch.Tensor): tokens = input tokens = tokens.squeeze() # Get rid of a trivial batch dimension @@ -891,6 +894,7 @@ def to_str_tokens( ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" else: raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}") + str_tokens = self.tokenizer.batch_decode( tokens, clean_up_tokenization_spaces=False ) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 3a0d06b5d..4bad9c179 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -166,6 +166,10 @@ "microsoft/phi-1", "microsoft/phi-1_5", "microsoft/phi-2", + "google/gemma-2b", + "google/gemma-7b", + "google/gemma-2b-it", + "google/gemma-7b-it", ] """Official model names for models on HuggingFace.""" @@ -566,6 +570,10 @@ "microsoft/phi-1": ["phi-1"], "microsoft/phi-1_5": ["phi-1_5"], "microsoft/phi-2": ["phi-2"], + "google/gemma-2b": ["gemma-2b"], + "google/gemma-7b": ["gemma-7b"], + "google/gemma-2b-it": ["gemma-2b-it"], + "google/gemma-7b-it": ["gemma-7b-it"], } """Model aliases for models on HuggingFace.""" @@ -634,6 +642,8 @@ def convert_hf_model_config(model_name: str, **kwargs): architecture = "LlamaForCausalLM" elif "mistral" in official_model_name.lower(): architecture = "MistralForCausalLM" + elif "gemma" in official_model_name.lower(): + architecture = "GemmaForCausalLM" else: hf_config = 
AutoConfig.from_pretrained(official_model_name, **kwargs) architecture = hf_config.architectures[0] @@ -962,6 +972,50 @@ def convert_hf_model_config(model_name: str, **kwargs): partial_rotary_factor = hf_config.partial_rotary_factor cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"]) + elif official_model_name.startswith("google/gemma-2b"): + # Architecture for Gemma 2b and Gemma 2b Instruct models + cfg_dict = { + "d_model": 2048, + "d_head": 256, + "n_heads": 8, + "d_mlp": 16384, + "n_layers": 18, + "n_ctx": 8192, + "eps": 1e-06, + "d_vocab": 256000, + "act_fn": "gelu", + "initializer_range": 0.02, + "normalization_type": "RMS", + "rotary_base": 10000.0, + "rotary_dim": 256, + "positional_embedding_type": "rotary", + "use_attn_scale": True, + "n_key_value_heads": 1, + "gated_mlp": True, + "final_rms": True, + } + elif official_model_name.startswith("google/gemma-7b"): + # Architecture for Gemma 7b and Gemma 7b Instruct models + cfg_dict = { + "d_model": 3072, + "d_head": 256, + "n_heads": 16, + "d_mlp": 24576, + "n_layers": 28, + "n_ctx": 8192, + "eps": 1e-06, + "d_vocab": 256000, + "act_fn": "gelu", + "initializer_range": 0.02, + "normalization_type": "RMS", + "rotary_base": 10000.0, + "rotary_dim": 256, + "positional_embedding_type": "rotary", + "use_attn_scale": True, + "n_key_value_heads": 16, + "gated_mlp": True, + "final_rms": True, + } else: raise NotImplementedError(f"{architecture} is not currently supported.") # All of these models use LayerNorm @@ -1312,6 +1366,8 @@ def get_pretrained_state_dict( state_dict = convert_qwen2_weights(hf_model, cfg) elif cfg.original_architecture == "PhiForCausalLM": state_dict = convert_phi_weights(hf_model, cfg) + elif cfg.original_architecture == "GemmaForCausalLM": + state_dict = convert_gemma_weights(hf_model, cfg) else: raise ValueError( f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. 
Feel free to open an issue on GitHub to request this feature." @@ -2383,6 +2439,76 @@ def convert_phi_weights(phi, cfg: HookedTransformerConfig): return state_dict +def convert_gemma_weights(gemma, cfg: HookedTransformerConfig): + state_dict = {} + + # Gemma Models scale embeddings by multiplying by sqrt(d_model) + state_dict["embed.W_E"] = gemma.model.embed_tokens.weight * (cfg.d_model**0.5) + + # Gemma has no biases anywhere + for l in range(cfg.n_layers): + # GemmaRMSNorm adds 1 to weights before multiplying by input + state_dict[f"blocks.{l}.ln1.w"] = gemma.model.layers[ + l + ].input_layernorm.weight + torch.ones_like( + gemma.model.layers[l].input_layernorm.weight, dtype=cfg.dtype + ) + + W_Q = gemma.model.layers[l].self_attn.q_proj.weight + W_K = gemma.model.layers[l].self_attn.k_proj.weight + W_V = gemma.model.layers[l].self_attn.v_proj.weight + W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) + W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads) + W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads) + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q + state_dict[f"blocks.{l}.attn._W_K"] = W_K + state_dict[f"blocks.{l}.attn._W_V"] = W_V + + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( + cfg.n_heads, cfg.d_head, dtype=cfg.dtype + ) + state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( + cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype + ) + state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( + cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype + ) + + W_O = gemma.model.layers[l].self_attn.o_proj.weight + W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) + state_dict[f"blocks.{l}.attn.W_O"] = W_O + + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + # GemmaRMSNorm adds 1 to weights before multiplying by input + state_dict[f"blocks.{l}.ln2.w"] = gemma.model.layers[ + l + ].post_attention_layernorm.weight + torch.ones_like( + gemma.model.norm.weight, dtype=cfg.dtype + ) + 
+ state_dict[f"blocks.{l}.mlp.W_in"] = gemma.model.layers[l].mlp.up_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_gate"] = gemma.model.layers[ + l + ].mlp.gate_proj.weight.T + state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.mlp.W_out"] = gemma.model.layers[ + l + ].mlp.down_proj.weight.T + state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + # GemmaRMSNorm adds 1 to weights before multiplying by input + state_dict["ln_final.w"] = gemma.model.norm.weight + torch.ones_like( + gemma.model.norm.weight, dtype=cfg.dtype + ) + + state_dict["unembed.W_U"] = gemma.lm_head.weight.T + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) + + return state_dict + + @dataclasses.dataclass class Config: d_model: int = 768 From 93c224660ebf0a8b604095a9e812a60529d106a8 Mon Sep 17 00:00:00 2001 From: Joseph Bloom <69127271+jbloomAus@users.noreply.github.com> Date: Tue, 26 Mar 2024 19:19:44 +0000 Subject: [PATCH 38/73] make tests pass mps (#528) --- tests/acceptance/test_hooked_transformer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index 7f3878161..b267a3278 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -311,8 +311,12 @@ def check_dtype(dtype, margin, no_processing=False): gc.collect() +@pytest.mark.skipif( + torch.backends.mps.is_available() or not torch.cuda.is_available(), + reason="some operations unsupported by MPS: https://github.com/pytorch/pytorch/issues/77754 or no GPU", +) @pytest.mark.parametrize("dtype", [torch.float64, torch.float32]) -def test_dtypes(dtype): +def test_dtype_float(dtype): check_dtype(dtype, margin=5e-4) From f6892d43da669045251f427f3e0c311e1a795911 Mon Sep 17 00:00:00 2001 From: sheikheddy Date: Thu, 28 Mar 2024 09:45:03 -0700 Subject: [PATCH 39/73] Add support for 
Llama-2-70b-chat-hf (#525) * Add support for Llama-2-70b-chat-hf * Add othello again --- transformer_lens/loading_from_pretrained.py | 53 +++++++++++++++++---- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 4bad9c179..00b4483f5 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -115,6 +115,7 @@ "meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-13b-hf", "meta-llama/Llama-2-13b-chat-hf", + "meta-llama/Llama-2-70b-chat-hf", "CodeLlama-7b-hf", "CodeLlama-7b-Python-hf", "CodeLlama-7b-Instruct-hf", @@ -502,6 +503,7 @@ "Llama-2-13b-chat", "meta-llama/Llama-2-13b-chat-hf", ], + "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"], "CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"], "CodeLlama-7b-Python-hf": [ "CodeLlama-7b-python", @@ -511,7 +513,6 @@ "CodeLlama-7b-instruct", "codellama/CodeLlama-7b-Instruct-hf", ], - # TODO Llama-2-70b-hf requires Grouped-Query Attention, see the paper https://arxiv.org/pdf/2307.09288.pdf "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"], "roneneldan/TinyStories-1M": ["tiny-stories-1M"], "roneneldan/TinyStories-3M": ["tiny-stories-3M"], @@ -746,6 +747,25 @@ def convert_hf_model_config(model_name: str, **kwargs): "final_rms": True, "gated_mlp": True, } + elif "Llama-2-70b" in official_model_name: + cfg_dict = { + "d_model": 8192, + "d_head": 128, + "n_heads": 64, + "d_mlp": 28672, + "n_layers": 80, + "n_ctx": 4096, + "eps": 1e-5, + "d_vocab": 32000, + "act_fn": "silu", + "n_key_value_heads": 8, + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "rotary_adjacent_pairs": False, + "rotary_dim": 128, + "final_rms": True, + "gated_mlp": True, + } elif architecture == "GPTNeoForCausalLM": cfg_dict = { "d_model": hf_config.hidden_size, @@ -1642,6 +1662,9 @@ def convert_llama_weights(llama, cfg: 
HookedTransformerConfig): state_dict["embed.W_E"] = llama.model.embed_tokens.weight + using_gqa = cfg.n_key_value_heads is not None + gqa_uscore = "_" if using_gqa else "" + # llama has no biases anywhere and deals with everything else roughly like # GPTNeoX with different names @@ -1652,20 +1675,30 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): W_K = llama.model.layers[l].self_attn.k_proj.weight W_V = llama.model.layers[l].self_attn.v_proj.weight W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) - W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads) - W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads) + W_K = einops.rearrange( + W_K, "(n h) m->n m h", n=cfg.n_key_value_heads if using_gqa else cfg.n_heads + ) + W_V = einops.rearrange( + W_V, "(n h) m-n m h", n=cfg.n_key_value_heads if using_gqa else cfg.n_heads + ) state_dict[f"blocks.{l}.attn.W_Q"] = W_Q - state_dict[f"blocks.{l}.attn.W_K"] = W_K - state_dict[f"blocks.{l}.attn.W_V"] = W_V + state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K + state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device ) - state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros( - cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device - ) - state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros( - cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device + state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros( + cfg.n_key_value_heads if using_gqa else cfg.n_heads, + cfg.d_head, + dtype=cfg.dtype, + device=cfg.device, + ) + state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros( + cfg.n_key_value_heads if using_gqa else cfg.n_heads, + cfg.d_head, + dtype=cfg.dtype, + device=cfg.device, ) W_O = llama.model.layers[l].self_attn.o_proj.weight From edf40dfb01453bc4333a49e6ece5cc55d8aa2639 Mon Sep 17 00:00:00 2001 From: Joseph Bloom <69127271+jbloomAus@users.noreply.github.com> 
Date: Thu, 28 Mar 2024 20:55:35 +0000 Subject: [PATCH 40/73] Update loading_from_pretrained.py (#529) fix missing arrow in PR #525 Thanks @collingray --- transformer_lens/loading_from_pretrained.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 00b4483f5..fa408bc96 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1679,7 +1679,7 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): W_K, "(n h) m->n m h", n=cfg.n_key_value_heads if using_gqa else cfg.n_heads ) W_V = einops.rearrange( - W_V, "(n h) m-n m h", n=cfg.n_key_value_heads if using_gqa else cfg.n_heads + W_V, "(n h) m->n m h", n=cfg.n_key_value_heads if using_gqa else cfg.n_heads ) state_dict[f"blocks.{l}.attn.W_Q"] = W_Q state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K From 14b8e2effa6143eb6ca64a0bc986e1d32f1ef488 Mon Sep 17 00:00:00 2001 From: Toni Kukurin Date: Tue, 2 Apr 2024 01:26:33 +0200 Subject: [PATCH 41/73] Bugfix: pytest import (#532) * disregard pytest import fail * also fix typo --- transformer_lens/utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index c3fe2963e..c549d6066 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -2,6 +2,7 @@ This module contains varied utility functions used throughout the library. """ + from __future__ import annotations import inspect @@ -13,7 +14,6 @@ import einops import numpy as np -import pytest import torch import torch.nn.functional as F import transformers @@ -600,9 +600,6 @@ def remove_batch_dim( return tensor -# Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit -# test (because it's name is prefixed `test_`). 
-@pytest.mark.skip def test_prompt( prompt: str, answer: str, @@ -1141,3 +1138,13 @@ def get_tokens_with_bos_removed(tokenizer, tokens): dim=1, index=real_bos_positions.unsqueeze(-1), value=-100 ) return tokens[tokens != -100].view(*bos_removed_shape) + + +try: + import pytest + + # Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit + # test (because its name is prefixed `test_`). + pytest.mark.skip(test_prompt) +except ModuleNotFoundError: + pass # disregard if pytest not in env From 5bf9acbb9c2fd6dcc9ea84c2b755b2a34fccfc82 Mon Sep 17 00:00:00 2001 From: Vasil Georgiev <149842188+VasilGeorgiev39@users.noreply.github.com> Date: Tue, 2 Apr 2024 12:22:57 +0300 Subject: [PATCH 42/73] Remove non-existing parameter from decompose_resid documentation (#504) --- transformer_lens/ActivationCache.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index 7ee732005..cf25f4eeb 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -607,9 +607,6 @@ def decompose_resid( layer==n_layers means to return all layer outputs incl in the final layer, layer==0 means just embed and pos_embed. The indices are taken such that this gives the accumulated streams up to the input to layer l - incl_mid: - Whether to return resid_mid for all previous - layers. mlp_input: Whether to include attn_out for the current layer - essentially decomposing the residual stream that's input to the MLP input From f773b29f10ee07fe4f5497cf8a2070e07da298f3 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 2 Apr 2024 10:24:17 +0100 Subject: [PATCH 43/73] Add `@overload` to `FactoredMatrix.__{,r}matmul__` (#512) This should `FactoredMatrix.__matmul__` to be used in `functools.reduce` without triggering type errors. 
--- transformer_lens/FactoredMatrix.py | 37 +++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/transformer_lens/FactoredMatrix.py b/transformer_lens/FactoredMatrix.py index c097646c4..7f72a5df6 100644 --- a/transformer_lens/FactoredMatrix.py +++ b/transformer_lens/FactoredMatrix.py @@ -3,10 +3,11 @@ Utilities for representing a matrix as a product of two matrices, and for efficient calculation of eigenvalues, norm and SVD. """ + from __future__ import annotations from functools import lru_cache -from typing import List, Tuple, Union +from typing import List, Tuple, Union, overload import torch from jaxtyping import Float @@ -40,6 +41,23 @@ def __init__( self.A = self.A.broadcast_to(self.shape[:-2] + (self.ldim, self.mdim)) self.B = self.B.broadcast_to(self.shape[:-2] + (self.mdim, self.rdim)) + @overload + def __matmul__( + self, + other: Union[ + Float[torch.Tensor, "... rdim new_rdim"], + "FactoredMatrix", + ], + ) -> "FactoredMatrix": + ... + + @overload + def __matmul__( + self, + other: Float[torch.Tensor, "rdim"], + ) -> Float[torch.Tensor, "... ldim"]: + ... + def __matmul__( self, other: Union[ @@ -64,6 +82,23 @@ def __matmul__( elif isinstance(other, FactoredMatrix): return (self @ other.A) @ other.B + @overload + def __rmatmul__( + self, + other: Union[ + Float[torch.Tensor, "... new_rdim ldim"], + "FactoredMatrix", + ], + ) -> "FactoredMatrix": + ... + + @overload + def __rmatmul__( + self, + other: Float[torch.Tensor, "ldim"], + ) -> Float[torch.Tensor, "... rdim"]: + ... 
+ def __rmatmul__( + self, + other: Union[ From 42c160237269829ba8b8181a21ea44b58007147c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Hofst=C3=A4tter?= Date: Tue, 2 Apr 2024 02:26:00 -0700 Subject: [PATCH 44/73] Explain abstract attribute in more detail (#508) --- transformer_lens/components.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_lens/components.py b/transformer_lens/components.py index 942ec2819..cdda30ccd 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -393,7 +393,8 @@ def __init__( """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks. Query and Output projections are defined in this class as they are the same for regular and grouped query attention. - Attributes related to Key and Value projections are abstract as their implementations may differ. + Attributes related to Key and Value projections are abstract as their implementations may differ. For example, in GroupedQueryAttention there are fewer key and value heads than query heads. + To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property.
Args: cfg (Union[Dict, HookedTransformerConfig]): Config From 3f5db9f86459fde97ae54265a849a11aa4f72daf Mon Sep 17 00:00:00 2001 From: Vasil Georgiev <149842188+VasilGeorgiev39@users.noreply.github.com> Date: Wed, 3 Apr 2024 02:36:22 +0300 Subject: [PATCH 45/73] Add pos_slice to run_with_cache (#465) * Add pos_slice to run_with_cache * Fixed formatting * Fix slicing res stream tensors with too few dimensions * Fix pos_slice in run_with_cache collapses the pos dimension * Fix run_with_cache with pos_slice now correctly slices attention heads * Add unit tests for run_with_cache with pos_slice * Fix format * Add documentation for pos_slice in run_with_cache --------- Co-authored-by: Bryce Meyer --- tests/unit/test_cache_pos_slice.py | 269 +++++++++++++++++++++++++++++ transformer_lens/hook_points.py | 57 ++++-- 2 files changed, 315 insertions(+), 11 deletions(-) create mode 100644 tests/unit/test_cache_pos_slice.py diff --git a/tests/unit/test_cache_pos_slice.py b/tests/unit/test_cache_pos_slice.py new file mode 100644 index 000000000..41e6982b7 --- /dev/null +++ b/tests/unit/test_cache_pos_slice.py @@ -0,0 +1,269 @@ +# %% + +import torch + +from transformer_lens import HookedTransformer + +MODEL = "tiny-stories-1M" + +prompt = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."
+model = HookedTransformer.from_pretrained(MODEL) +# %% +d_model = model.cfg.d_model +d_head = model.cfg.d_head +n_heads = model.cfg.n_heads +n_layers = model.cfg.n_layers +# %% + + +def test_run_with_cache_pos_slice_keep_batch(): + _, cache_no_slice = model.run_with_cache(prompt, return_type=None) + num_tokens = len(model.tokenizer.encode(prompt)) + + for i in range(-1, num_tokens + 1): + _, cache_with_slice = model.run_with_cache( + prompt, return_type=None, pos_slice=i + ) + + assert cache_with_slice["embed"].shape == torch.Size([1, 1, d_model]) + assert cache_with_slice["q", 0].shape == torch.Size([1, 1, n_heads, d_head]) + + assert torch.equal( + cache_no_slice["embed"][0, i, :], cache_with_slice["embed"][0, 0, :] + ) + assert torch.equal( + cache_no_slice["pos_embed"][0, i, :], cache_with_slice["pos_embed"][0, 0, :] + ) + + for layer in range(n_layers): + assert torch.equal( + cache_no_slice["resid_pre", layer][0, i, :], + cache_with_slice["resid_pre", layer][0, 0, :], + ) + assert torch.equal( + cache_no_slice["resid_post", layer][0, i, :], + cache_with_slice["resid_post", layer][0, 0, :], + ) + assert torch.equal( + cache_no_slice["resid_mid", layer][0, i, :], + cache_with_slice["resid_mid", layer][0, 0, :], + ) + assert torch.equal( + cache_no_slice["scale", layer, "ln1"][0, i, :], + cache_with_slice["scale", layer, "ln1"][0, 0, :], + ) + assert torch.equal( + cache_no_slice["scale", layer, "ln2"][0, i, :], + cache_with_slice["scale", layer, "ln2"][0, 0, :], + ) + assert torch.equal( + cache_no_slice["normalized", layer, "ln1"][0, i, :], + cache_with_slice["normalized", layer, "ln1"][0, 0, :], + ) + assert torch.equal( + cache_no_slice["normalized", layer, "ln2"][0, i, :], + cache_with_slice["normalized", layer, "ln2"][0, 0, :], + ) + assert torch.equal( + cache_no_slice[ + "q", + layer, + ][0, i, :, :], + cache_with_slice[ + "q", + layer, + ][0, 0, :, :], + ) + assert torch.equal( + cache_no_slice[ + "k", + layer, + ][0, i, :, :], + cache_with_slice[ + 
"k", + layer, + ][0, 0, :, :], + ) + assert torch.equal( + cache_no_slice[ + "v", + layer, + ][0, i, :, :], + cache_with_slice[ + "v", + layer, + ][0, 0, :, :], + ) + assert torch.equal( + cache_no_slice[ + "z", + layer, + ][0, i, :, :], + cache_with_slice[ + "z", + layer, + ][0, 0, :, :], + ) + assert torch.equal( + cache_no_slice[ + "attn_scores", + layer, + ][0, :, i, :], + cache_with_slice[ + "attn_scores", + layer, + ][0, :, 0, :], + ) + assert torch.equal( + cache_no_slice[ + "pattern", + layer, + ][0, :, i, :], + cache_with_slice[ + "pattern", + layer, + ][0, :, 0, :], + ) + assert torch.equal( + cache_no_slice["attn_out", layer][0, i, :], + cache_with_slice["attn_out", layer][0, 0, :], + ) + assert torch.equal( + cache_no_slice["pre", layer][0, i, :], + cache_with_slice["pre", layer][0, 0, :], + ) + assert torch.equal( + cache_no_slice["post", layer][0, i, :], + cache_with_slice["post", layer][0, 0, :], + ) + assert torch.equal( + cache_no_slice["mlp_out", layer][0, i, :], + cache_with_slice["mlp_out", layer][0, 0, :], + ) + + +def test_run_with_cache_pos_slice_remove_batch(): + _, cache_no_slice = model.run_with_cache( + prompt, remove_batch_dim=True, return_type=None + ) + num_tokens = len(model.tokenizer.encode(prompt)) + + for i in range(-1, num_tokens + 1): + _, cache_with_slice = model.run_with_cache( + prompt, remove_batch_dim=True, pos_slice=i + ) + + assert cache_with_slice["embed"].shape == torch.Size([1, d_model]) + assert cache_with_slice["q", 0].shape == torch.Size([1, n_heads, d_head]) + + assert torch.equal( + cache_no_slice["embed"][i, :], cache_with_slice["embed"][0, :] + ) + assert torch.equal( + cache_no_slice["pos_embed"][i, :], cache_with_slice["pos_embed"][0, :] + ) + + for layer in range(n_layers): + assert torch.equal( + cache_no_slice["resid_pre", layer][i, :], + cache_with_slice["resid_pre", layer][0, :], + ) + assert torch.equal( + cache_no_slice["resid_post", layer][i, :], + cache_with_slice["resid_post", layer][0, :], + ) + 
assert torch.equal( + cache_no_slice["resid_mid", layer][i, :], + cache_with_slice["resid_mid", layer][0, :], + ) + assert torch.equal( + cache_no_slice["scale", layer, "ln1"][i, :], + cache_with_slice["scale", layer, "ln1"][0, :], + ) + assert torch.equal( + cache_no_slice["scale", layer, "ln2"][i, :], + cache_with_slice["scale", layer, "ln2"][0, :], + ) + assert torch.equal( + cache_no_slice["normalized", layer, "ln1"][i, :], + cache_with_slice["normalized", layer, "ln1"][0, :], + ) + assert torch.equal( + cache_no_slice["normalized", layer, "ln2"][i, :], + cache_with_slice["normalized", layer, "ln2"][0, :], + ) + assert torch.equal( + cache_no_slice[ + "q", + layer, + ][i, :, :], + cache_with_slice[ + "q", + layer, + ][0, :, :], + ) + assert torch.equal( + cache_no_slice[ + "k", + layer, + ][i, :, :], + cache_with_slice[ + "k", + layer, + ][0, :, :], + ) + assert torch.equal( + cache_no_slice[ + "v", + layer, + ][i, :, :], + cache_with_slice[ + "v", + layer, + ][0, :, :], + ) + assert torch.equal( + cache_no_slice[ + "z", + layer, + ][i, :, :], + cache_with_slice[ + "z", + layer, + ][0, :, :], + ) + assert torch.equal( + cache_no_slice[ + "attn_scores", + layer, + ][:, i, :], + cache_with_slice[ + "attn_scores", + layer, + ][:, 0, :], + ) + assert torch.equal( + cache_no_slice[ + "pattern", + layer, + ][:, i, :], + cache_with_slice[ + "pattern", + layer, + ][:, 0, :], + ) + assert torch.equal( + cache_no_slice["attn_out", layer][i, :], + cache_with_slice["attn_out", layer][0, :], + ) + assert torch.equal( + cache_no_slice["pre", layer][i, :], cache_with_slice["pre", layer][0, :] + ) + assert torch.equal( + cache_no_slice["post", layer][i, :], + cache_with_slice["post", layer][0, :], + ) + assert torch.equal( + cache_no_slice["mlp_out", layer][i, :], + cache_with_slice["mlp_out", layer][0, :], + ) diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py index ce86fd77a..313c6d4d6 100644 --- a/transformer_lens/hook_points.py +++ 
b/transformer_lens/hook_points.py @@ -5,11 +5,14 @@ import logging from contextlib import contextmanager from dataclasses import dataclass +from functools import partial from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import torch.nn as nn import torch.utils.hooks as hooks +from transformer_lens.utils import Slice + @dataclass class LensHandle: @@ -426,6 +429,7 @@ def run_with_cache( incl_bwd=False, reset_hooks_end=True, clear_contexts=False, + pos_slice=None, **model_kwargs, ): """ @@ -448,14 +452,28 @@ def run_with_cache( end of the run. Defaults to True. clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset. Defaults to False. + pos_slice: + The slice to apply to the cache output. Defaults to None, do nothing. **model_kwargs: Keyword arguments for the model. Returns: tuple: A tuple containing the model output and a Cache object. """ + + if not isinstance(pos_slice, Slice): + if isinstance( + pos_slice, int + ): # slicing with an int collapses the dimension so this stops the pos dimension from collapsing + pos_slice = [pos_slice] + pos_slice = Slice(pos_slice) + cache_dict, fwd, bwd = self.get_caching_hooks( - names_filter, incl_bwd, device, remove_batch_dim=remove_batch_dim + names_filter, + incl_bwd, + device, + remove_batch_dim=remove_batch_dim, + pos_slice=pos_slice, ) with self.hooks( @@ -477,6 +495,7 @@ def get_caching_hooks( device=None, remove_batch_dim: bool = False, cache: Optional[dict] = None, + pos_slice: Slice = None, ) -> Tuple[dict, list, list]: """Creates hooks to cache activations. Note: It does not add the hooks to the model. 
@@ -505,25 +524,41 @@ def get_caching_hooks( names_filter = lambda name: name in filter_list self.is_caching = True - def save_hook(tensor, hook): + def save_hook(tensor, hook, is_backward=False): + hook_name = hook.name + if is_backward: + hook_name += "_grad" + resid_stream = tensor.detach().to(device) if remove_batch_dim: - cache[hook.name] = tensor.detach().to(device)[0] + resid_stream = resid_stream[0] + + # for attention heads the pos dimension is the third from last + if ( + hook.name.endswith("hook_q") + or hook.name.endswith("hook_k") + or hook.name.endswith("hook_v") + or hook.name.endswith("hook_z") + or hook.name.endswith("hook_result") + ): + pos_dim = -3 else: - cache[hook.name] = tensor.detach().to(device) + # for all other components the pos dimension is the second from last + # including the attn scores where the dest token is the second from last + pos_dim = -2 - def save_hook_back(tensor, hook): - if remove_batch_dim: - cache[hook.name + "_grad"] = tensor.detach().to(device)[0] - else: - cache[hook.name + "_grad"] = tensor.detach().to(device) + if ( + tensor.dim() >= -pos_dim + ): # check if the residual stream has a pos dimension before trying to slice + resid_stream = pos_slice.apply(resid_stream, dim=pos_dim) + cache[hook_name] = resid_stream fwd_hooks = [] bwd_hooks = [] for name, hp in self.hook_dict.items(): if names_filter(name): - fwd_hooks.append((name, save_hook)) + fwd_hooks.append((name, partial(save_hook, is_backward=False))) if incl_bwd: - bwd_hooks.append((name, save_hook_back)) + bwd_hooks.append((name, partial(save_hook, is_backward=True))) return cache, fwd_hooks, bwd_hooks From 72d5ae33a6611e74223e5fa228b5e2956c079fad Mon Sep 17 00:00:00 2001 From: Collin Date: Tue, 2 Apr 2024 17:08:30 -0700 Subject: [PATCH 46/73] Add Support for Yi-6B and Yi-34B (#494) * add LlamaForCausalLM arch. parsing and 01-ai/Yi * fix attn bias dim error * fix attn dim error... 
again * add chat models * format * add sentencepiece for yi-chat tokenizers * update poetry.lock * update gqa comment * update poetry.lock --------- Co-authored-by: Bryce Meyer --- poetry.lock | 2330 ++++++++++--------- pyproject.toml | 1 + transformer_lens/loading_from_pretrained.py | 34 + 3 files changed, 1284 insertions(+), 1081 deletions(-) diff --git a/poetry.lock b/poetry.lock index af0799dd7..c5e5c03a3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,13 +2,13 @@ [[package]] name = "accelerate" -version = "0.25.0" +version = "0.28.0" description = "Accelerate" optional = false python-versions = ">=3.8.0" files = [ - {file = "accelerate-0.25.0-py3-none-any.whl", hash = "sha256:c7bb817eb974bba0ff3ea1ba0f24d55afb86d50e3d4fe98d6922dc69cf2ccff1"}, - {file = "accelerate-0.25.0.tar.gz", hash = "sha256:ecf55b0ab278a1dac8539dde0d276977aff04683f07ede73eaf02478538576a1"}, + {file = "accelerate-0.28.0-py3-none-any.whl", hash = "sha256:8ae25f8a8dc4cf12283842c469113836300545fb0dfa46fef331fb0a2ac8b421"}, + {file = "accelerate-0.28.0.tar.gz", hash = "sha256:32019a49f4b3a85cc179ac4e38e9e2971f1a997dee026be0512816499464c4d5"}, ] [package.dependencies] @@ -21,98 +21,98 @@ safetensors = ">=0.3.1" torch = ">=1.10.0" [package.extras] -dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "deepspeed", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.0.241)", "scikit-learn", "scipy", "timm", "tqdm", "transformers", "urllib3 (<2.0.0)"] -quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.0.241)", "urllib3 (<2.0.0)"] +dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "deepspeed (<0.13.0)", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.2.1,<0.3.0)", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +quality = ["black (>=23.1,<24.0)", "hf-doc-builder 
(>=0.3.0)", "ruff (>=0.2.1,<0.3.0)"] rich = ["rich"] sagemaker = ["sagemaker"] -test-dev = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "timm", "tqdm", "transformers"] -test-prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"] +test-dev = ["bitsandbytes", "datasets", "deepspeed (<0.13.0)", "evaluate", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +test-prod = ["parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist"] test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"] -testing = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "tqdm", "transformers"] +testing = ["bitsandbytes", "datasets", "deepspeed (<0.13.0)", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] [[package]] name = "aiohttp" -version = "3.9.1" +version = "3.9.3" description = "Async http client/server framework (asyncio)" optional = false python-versions = ">=3.8" files = [ - {file = "aiohttp-3.9.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e1f80197f8b0b846a8d5cf7b7ec6084493950d0882cc5537fb7b96a69e3c8590"}, - {file = "aiohttp-3.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c72444d17777865734aa1a4d167794c34b63e5883abb90356a0364a28904e6c0"}, - {file = "aiohttp-3.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9b05d5cbe9dafcdc733262c3a99ccf63d2f7ce02543620d2bd8db4d4f7a22f83"}, - {file = "aiohttp-3.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c4fa235d534b3547184831c624c0b7c1e262cd1de847d95085ec94c16fddcd5"}, - {file = "aiohttp-3.9.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:289ba9ae8e88d0ba16062ecf02dd730b34186ea3b1e7489046fc338bdc3361c4"}, - {file = 
"aiohttp-3.9.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bff7e2811814fa2271be95ab6e84c9436d027a0e59665de60edf44e529a42c1f"}, - {file = "aiohttp-3.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81b77f868814346662c96ab36b875d7814ebf82340d3284a31681085c051320f"}, - {file = "aiohttp-3.9.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b9c7426923bb7bd66d409da46c41e3fb40f5caf679da624439b9eba92043fa6"}, - {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:8d44e7bf06b0c0a70a20f9100af9fcfd7f6d9d3913e37754c12d424179b4e48f"}, - {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22698f01ff5653fe66d16ffb7658f582a0ac084d7da1323e39fd9eab326a1f26"}, - {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ca7ca5abfbfe8d39e653870fbe8d7710be7a857f8a8386fc9de1aae2e02ce7e4"}, - {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:8d7f98fde213f74561be1d6d3fa353656197f75d4edfbb3d94c9eb9b0fc47f5d"}, - {file = "aiohttp-3.9.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5216b6082c624b55cfe79af5d538e499cd5f5b976820eac31951fb4325974501"}, - {file = "aiohttp-3.9.1-cp310-cp310-win32.whl", hash = "sha256:0e7ba7ff228c0d9a2cd66194e90f2bca6e0abca810b786901a569c0de082f489"}, - {file = "aiohttp-3.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:c7e939f1ae428a86e4abbb9a7c4732bf4706048818dfd979e5e2839ce0159f23"}, - {file = "aiohttp-3.9.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:df9cf74b9bc03d586fc53ba470828d7b77ce51b0582d1d0b5b2fb673c0baa32d"}, - {file = "aiohttp-3.9.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ecca113f19d5e74048c001934045a2b9368d77b0b17691d905af18bd1c21275e"}, - {file = "aiohttp-3.9.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8cef8710fb849d97c533f259103f09bac167a008d7131d7b2b0e3a33269185c0"}, - {file = 
"aiohttp-3.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bea94403a21eb94c93386d559bce297381609153e418a3ffc7d6bf772f59cc35"}, - {file = "aiohttp-3.9.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91c742ca59045dce7ba76cab6e223e41d2c70d79e82c284a96411f8645e2afff"}, - {file = "aiohttp-3.9.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6c93b7c2e52061f0925c3382d5cb8980e40f91c989563d3d32ca280069fd6a87"}, - {file = "aiohttp-3.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee2527134f95e106cc1653e9ac78846f3a2ec1004cf20ef4e02038035a74544d"}, - {file = "aiohttp-3.9.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:11ff168d752cb41e8492817e10fb4f85828f6a0142b9726a30c27c35a1835f01"}, - {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b8c3a67eb87394386847d188996920f33b01b32155f0a94f36ca0e0c635bf3e3"}, - {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c7b5d5d64e2a14e35a9240b33b89389e0035e6de8dbb7ffa50d10d8b65c57449"}, - {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:69985d50a2b6f709412d944ffb2e97d0be154ea90600b7a921f95a87d6f108a2"}, - {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:c9110c06eaaac7e1f5562caf481f18ccf8f6fdf4c3323feab28a93d34cc646bd"}, - {file = "aiohttp-3.9.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d737e69d193dac7296365a6dcb73bbbf53bb760ab25a3727716bbd42022e8d7a"}, - {file = "aiohttp-3.9.1-cp311-cp311-win32.whl", hash = "sha256:4ee8caa925aebc1e64e98432d78ea8de67b2272252b0a931d2ac3bd876ad5544"}, - {file = "aiohttp-3.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:a34086c5cc285be878622e0a6ab897a986a6e8bf5b67ecb377015f06ed316587"}, - {file = "aiohttp-3.9.1-cp312-cp312-macosx_10_9_universal2.whl", hash = 
"sha256:f800164276eec54e0af5c99feb9494c295118fc10a11b997bbb1348ba1a52065"}, - {file = "aiohttp-3.9.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:500f1c59906cd142d452074f3811614be04819a38ae2b3239a48b82649c08821"}, - {file = "aiohttp-3.9.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0b0a6a36ed7e164c6df1e18ee47afbd1990ce47cb428739d6c99aaabfaf1b3af"}, - {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69da0f3ed3496808e8cbc5123a866c41c12c15baaaead96d256477edf168eb57"}, - {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:176df045597e674fa950bf5ae536be85699e04cea68fa3a616cf75e413737eb5"}, - {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b796b44111f0cab6bbf66214186e44734b5baab949cb5fb56154142a92989aeb"}, - {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f27fdaadce22f2ef950fc10dcdf8048407c3b42b73779e48a4e76b3c35bca26c"}, - {file = "aiohttp-3.9.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bcb6532b9814ea7c5a6a3299747c49de30e84472fa72821b07f5a9818bce0f66"}, - {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:54631fb69a6e44b2ba522f7c22a6fb2667a02fd97d636048478db2fd8c4e98fe"}, - {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4b4c452d0190c5a820d3f5c0f3cd8a28ace48c54053e24da9d6041bf81113183"}, - {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:cae4c0c2ca800c793cae07ef3d40794625471040a87e1ba392039639ad61ab5b"}, - {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:565760d6812b8d78d416c3c7cfdf5362fbe0d0d25b82fed75d0d29e18d7fc30f"}, - {file = "aiohttp-3.9.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:54311eb54f3a0c45efb9ed0d0a8f43d1bc6060d773f6973efd90037a51cd0a3f"}, - {file = 
"aiohttp-3.9.1-cp312-cp312-win32.whl", hash = "sha256:85c3e3c9cb1d480e0b9a64c658cd66b3cfb8e721636ab8b0e746e2d79a7a9eed"}, - {file = "aiohttp-3.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:11cb254e397a82efb1805d12561e80124928e04e9c4483587ce7390b3866d213"}, - {file = "aiohttp-3.9.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8a22a34bc594d9d24621091d1b91511001a7eea91d6652ea495ce06e27381f70"}, - {file = "aiohttp-3.9.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:598db66eaf2e04aa0c8900a63b0101fdc5e6b8a7ddd805c56d86efb54eb66672"}, - {file = "aiohttp-3.9.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2c9376e2b09895c8ca8b95362283365eb5c03bdc8428ade80a864160605715f1"}, - {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41473de252e1797c2d2293804e389a6d6986ef37cbb4a25208de537ae32141dd"}, - {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9c5857612c9813796960c00767645cb5da815af16dafb32d70c72a8390bbf690"}, - {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ffcd828e37dc219a72c9012ec44ad2e7e3066bec6ff3aaa19e7d435dbf4032ca"}, - {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:219a16763dc0294842188ac8a12262b5671817042b35d45e44fd0a697d8c8361"}, - {file = "aiohttp-3.9.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f694dc8a6a3112059258a725a4ebe9acac5fe62f11c77ac4dcf896edfa78ca28"}, - {file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:bcc0ea8d5b74a41b621ad4a13d96c36079c81628ccc0b30cfb1603e3dfa3a014"}, - {file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:90ec72d231169b4b8d6085be13023ece8fa9b1bb495e4398d847e25218e0f431"}, - {file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:cf2a0ac0615842b849f40c4d7f304986a242f1e68286dbf3bd7a835e4f83acfd"}, - 
{file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:0e49b08eafa4f5707ecfb321ab9592717a319e37938e301d462f79b4e860c32a"}, - {file = "aiohttp-3.9.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2c59e0076ea31c08553e868cec02d22191c086f00b44610f8ab7363a11a5d9d8"}, - {file = "aiohttp-3.9.1-cp38-cp38-win32.whl", hash = "sha256:4831df72b053b1eed31eb00a2e1aff6896fb4485301d4ccb208cac264b648db4"}, - {file = "aiohttp-3.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:3135713c5562731ee18f58d3ad1bf41e1d8883eb68b363f2ffde5b2ea4b84cc7"}, - {file = "aiohttp-3.9.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cfeadf42840c1e870dc2042a232a8748e75a36b52d78968cda6736de55582766"}, - {file = "aiohttp-3.9.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:70907533db712f7aa791effb38efa96f044ce3d4e850e2d7691abd759f4f0ae0"}, - {file = "aiohttp-3.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cdefe289681507187e375a5064c7599f52c40343a8701761c802c1853a504558"}, - {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7481f581251bb5558ba9f635db70908819caa221fc79ee52a7f58392778c636"}, - {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:49f0c1b3c2842556e5de35f122fc0f0b721334ceb6e78c3719693364d4af8499"}, - {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d406b01a9f5a7e232d1b0d161b40c05275ffbcbd772dc18c1d5a570961a1ca4"}, - {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d8e4450e7fe24d86e86b23cc209e0023177b6d59502e33807b732d2deb6975f"}, - {file = "aiohttp-3.9.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c0266cd6f005e99f3f51e583012de2778e65af6b73860038b968a0a8888487a"}, - {file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab221850108a4a063c5b8a70f00dd7a1975e5a1713f87f4ab26a46e5feac5a0e"}, - 
{file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c88a15f272a0ad3d7773cf3a37cc7b7d077cbfc8e331675cf1346e849d97a4e5"}, - {file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:237533179d9747080bcaad4d02083ce295c0d2eab3e9e8ce103411a4312991a0"}, - {file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:02ab6006ec3c3463b528374c4cdce86434e7b89ad355e7bf29e2f16b46c7dd6f"}, - {file = "aiohttp-3.9.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04fa38875e53eb7e354ece1607b1d2fdee2d175ea4e4d745f6ec9f751fe20c7c"}, - {file = "aiohttp-3.9.1-cp39-cp39-win32.whl", hash = "sha256:82eefaf1a996060602f3cc1112d93ba8b201dbf5d8fd9611227de2003dddb3b7"}, - {file = "aiohttp-3.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:9b05d33ff8e6b269e30a7957bd3244ffbce2a7a35a81b81c382629b80af1a8bf"}, - {file = "aiohttp-3.9.1.tar.gz", hash = "sha256:8fc49a87ac269d4529da45871e2ffb6874e87779c3d0e2ccd813c0899221239d"}, + {file = "aiohttp-3.9.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:939677b61f9d72a4fa2a042a5eee2a99a24001a67c13da113b2e30396567db54"}, + {file = "aiohttp-3.9.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1f5cd333fcf7590a18334c90f8c9147c837a6ec8a178e88d90a9b96ea03194cc"}, + {file = "aiohttp-3.9.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:82e6aa28dd46374f72093eda8bcd142f7771ee1eb9d1e223ff0fa7177a96b4a5"}, + {file = "aiohttp-3.9.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f56455b0c2c7cc3b0c584815264461d07b177f903a04481dfc33e08a89f0c26b"}, + {file = "aiohttp-3.9.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bca77a198bb6e69795ef2f09a5f4c12758487f83f33d63acde5f0d4919815768"}, + {file = "aiohttp-3.9.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e083c285857b78ee21a96ba1eb1b5339733c3563f72980728ca2b08b53826ca5"}, + {file = "aiohttp-3.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:ab40e6251c3873d86ea9b30a1ac6d7478c09277b32e14745d0d3c6e76e3c7e29"}, + {file = "aiohttp-3.9.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:df822ee7feaaeffb99c1a9e5e608800bd8eda6e5f18f5cfb0dc7eeb2eaa6bbec"}, + {file = "aiohttp-3.9.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:acef0899fea7492145d2bbaaaec7b345c87753168589cc7faf0afec9afe9b747"}, + {file = "aiohttp-3.9.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:cd73265a9e5ea618014802ab01babf1940cecb90c9762d8b9e7d2cc1e1969ec6"}, + {file = "aiohttp-3.9.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:a78ed8a53a1221393d9637c01870248a6f4ea5b214a59a92a36f18151739452c"}, + {file = "aiohttp-3.9.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:6b0e029353361f1746bac2e4cc19b32f972ec03f0f943b390c4ab3371840aabf"}, + {file = "aiohttp-3.9.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7cf5c9458e1e90e3c390c2639f1017a0379a99a94fdfad3a1fd966a2874bba52"}, + {file = "aiohttp-3.9.3-cp310-cp310-win32.whl", hash = "sha256:3e59c23c52765951b69ec45ddbbc9403a8761ee6f57253250c6e1536cacc758b"}, + {file = "aiohttp-3.9.3-cp310-cp310-win_amd64.whl", hash = "sha256:055ce4f74b82551678291473f66dc9fb9048a50d8324278751926ff0ae7715e5"}, + {file = "aiohttp-3.9.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6b88f9386ff1ad91ace19d2a1c0225896e28815ee09fc6a8932fded8cda97c3d"}, + {file = "aiohttp-3.9.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c46956ed82961e31557b6857a5ca153c67e5476972e5f7190015018760938da2"}, + {file = "aiohttp-3.9.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:07b837ef0d2f252f96009e9b8435ec1fef68ef8b1461933253d318748ec1acdc"}, + {file = "aiohttp-3.9.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dad46e6f620574b3b4801c68255492e0159d1712271cc99d8bdf35f2043ec266"}, + {file = "aiohttp-3.9.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:5ed3e046ea7b14938112ccd53d91c1539af3e6679b222f9469981e3dac7ba1ce"}, + {file = "aiohttp-3.9.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:039df344b45ae0b34ac885ab5b53940b174530d4dd8a14ed8b0e2155b9dddccb"}, + {file = "aiohttp-3.9.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7943c414d3a8d9235f5f15c22ace69787c140c80b718dcd57caaade95f7cd93b"}, + {file = "aiohttp-3.9.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84871a243359bb42c12728f04d181a389718710129b36b6aad0fc4655a7647d4"}, + {file = "aiohttp-3.9.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:5eafe2c065df5401ba06821b9a054d9cb2848867f3c59801b5d07a0be3a380ae"}, + {file = "aiohttp-3.9.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:9d3c9b50f19704552f23b4eaea1fc082fdd82c63429a6506446cbd8737823da3"}, + {file = "aiohttp-3.9.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:f033d80bc6283092613882dfe40419c6a6a1527e04fc69350e87a9df02bbc283"}, + {file = "aiohttp-3.9.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:2c895a656dd7e061b2fd6bb77d971cc38f2afc277229ce7dd3552de8313a483e"}, + {file = "aiohttp-3.9.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1f5a71d25cd8106eab05f8704cd9167b6e5187bcdf8f090a66c6d88b634802b4"}, + {file = "aiohttp-3.9.3-cp311-cp311-win32.whl", hash = "sha256:50fca156d718f8ced687a373f9e140c1bb765ca16e3d6f4fe116e3df7c05b2c5"}, + {file = "aiohttp-3.9.3-cp311-cp311-win_amd64.whl", hash = "sha256:5fe9ce6c09668063b8447f85d43b8d1c4e5d3d7e92c63173e6180b2ac5d46dd8"}, + {file = "aiohttp-3.9.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:38a19bc3b686ad55804ae931012f78f7a534cce165d089a2059f658f6c91fa60"}, + {file = "aiohttp-3.9.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:770d015888c2a598b377bd2f663adfd947d78c0124cfe7b959e1ef39f5b13869"}, + {file = "aiohttp-3.9.3-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:ee43080e75fc92bf36219926c8e6de497f9b247301bbf88c5c7593d931426679"}, + {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:52df73f14ed99cee84865b95a3d9e044f226320a87af208f068ecc33e0c35b96"}, + {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc9b311743a78043b26ffaeeb9715dc360335e5517832f5a8e339f8a43581e4d"}, + {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b955ed993491f1a5da7f92e98d5dad3c1e14dc175f74517c4e610b1f2456fb11"}, + {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:504b6981675ace64c28bf4a05a508af5cde526e36492c98916127f5a02354d53"}, + {file = "aiohttp-3.9.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a6fe5571784af92b6bc2fda8d1925cccdf24642d49546d3144948a6a1ed58ca5"}, + {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ba39e9c8627edc56544c8628cc180d88605df3892beeb2b94c9bc857774848ca"}, + {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e5e46b578c0e9db71d04c4b506a2121c0cb371dd89af17a0586ff6769d4c58c1"}, + {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:938a9653e1e0c592053f815f7028e41a3062e902095e5a7dc84617c87267ebd5"}, + {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:c3452ea726c76e92f3b9fae4b34a151981a9ec0a4847a627c43d71a15ac32aa6"}, + {file = "aiohttp-3.9.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ff30218887e62209942f91ac1be902cc80cddb86bf00fbc6783b7a43b2bea26f"}, + {file = "aiohttp-3.9.3-cp312-cp312-win32.whl", hash = "sha256:38f307b41e0bea3294a9a2a87833191e4bcf89bb0365e83a8be3a58b31fb7f38"}, + {file = "aiohttp-3.9.3-cp312-cp312-win_amd64.whl", hash = "sha256:b791a3143681a520c0a17e26ae7465f1b6f99461a28019d1a2f425236e6eedb5"}, + {file = 
"aiohttp-3.9.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0ed621426d961df79aa3b963ac7af0d40392956ffa9be022024cd16297b30c8c"}, + {file = "aiohttp-3.9.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7f46acd6a194287b7e41e87957bfe2ad1ad88318d447caf5b090012f2c5bb528"}, + {file = "aiohttp-3.9.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:feeb18a801aacb098220e2c3eea59a512362eb408d4afd0c242044c33ad6d542"}, + {file = "aiohttp-3.9.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f734e38fd8666f53da904c52a23ce517f1b07722118d750405af7e4123933511"}, + {file = "aiohttp-3.9.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b40670ec7e2156d8e57f70aec34a7216407848dfe6c693ef131ddf6e76feb672"}, + {file = "aiohttp-3.9.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fdd215b7b7fd4a53994f238d0f46b7ba4ac4c0adb12452beee724ddd0743ae5d"}, + {file = "aiohttp-3.9.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:017a21b0df49039c8f46ca0971b3a7fdc1f56741ab1240cb90ca408049766168"}, + {file = "aiohttp-3.9.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e99abf0bba688259a496f966211c49a514e65afa9b3073a1fcee08856e04425b"}, + {file = "aiohttp-3.9.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:648056db9a9fa565d3fa851880f99f45e3f9a771dd3ff3bb0c048ea83fb28194"}, + {file = "aiohttp-3.9.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8aacb477dc26797ee089721536a292a664846489c49d3ef9725f992449eda5a8"}, + {file = "aiohttp-3.9.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:522a11c934ea660ff8953eda090dcd2154d367dec1ae3c540aff9f8a5c109ab4"}, + {file = "aiohttp-3.9.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:5bce0dc147ca85caa5d33debc4f4d65e8e8b5c97c7f9f660f215fa74fc49a321"}, + {file = "aiohttp-3.9.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = 
"sha256:4b4af9f25b49a7be47c0972139e59ec0e8285c371049df1a63b6ca81fdd216a2"}, + {file = "aiohttp-3.9.3-cp38-cp38-win32.whl", hash = "sha256:298abd678033b8571995650ccee753d9458dfa0377be4dba91e4491da3f2be63"}, + {file = "aiohttp-3.9.3-cp38-cp38-win_amd64.whl", hash = "sha256:69361bfdca5468c0488d7017b9b1e5ce769d40b46a9f4a2eed26b78619e9396c"}, + {file = "aiohttp-3.9.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0fa43c32d1643f518491d9d3a730f85f5bbaedcbd7fbcae27435bb8b7a061b29"}, + {file = "aiohttp-3.9.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:835a55b7ca49468aaaac0b217092dfdff370e6c215c9224c52f30daaa735c1c1"}, + {file = "aiohttp-3.9.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:06a9b2c8837d9a94fae16c6223acc14b4dfdff216ab9b7202e07a9a09541168f"}, + {file = "aiohttp-3.9.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:abf151955990d23f84205286938796c55ff11bbfb4ccfada8c9c83ae6b3c89a3"}, + {file = "aiohttp-3.9.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59c26c95975f26e662ca78fdf543d4eeaef70e533a672b4113dd888bd2423caa"}, + {file = "aiohttp-3.9.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f95511dd5d0e05fd9728bac4096319f80615aaef4acbecb35a990afebe953b0e"}, + {file = "aiohttp-3.9.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:595f105710293e76b9dc09f52e0dd896bd064a79346234b521f6b968ffdd8e58"}, + {file = "aiohttp-3.9.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7c8b816c2b5af5c8a436df44ca08258fc1a13b449393a91484225fcb7545533"}, + {file = "aiohttp-3.9.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f1088fa100bf46e7b398ffd9904f4808a0612e1d966b4aa43baa535d1b6341eb"}, + {file = "aiohttp-3.9.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f59dfe57bb1ec82ac0698ebfcdb7bcd0e99c255bd637ff613760d5f33e7c81b3"}, + {file = "aiohttp-3.9.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = 
"sha256:361a1026c9dd4aba0109e4040e2aecf9884f5cfe1b1b1bd3d09419c205e2e53d"}, + {file = "aiohttp-3.9.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:363afe77cfcbe3a36353d8ea133e904b108feea505aa4792dad6585a8192c55a"}, + {file = "aiohttp-3.9.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8e2c45c208c62e955e8256949eb225bd8b66a4c9b6865729a786f2aa79b72e9d"}, + {file = "aiohttp-3.9.3-cp39-cp39-win32.whl", hash = "sha256:f7217af2e14da0856e082e96ff637f14ae45c10a5714b63c77f26d8884cf1051"}, + {file = "aiohttp-3.9.3-cp39-cp39-win_amd64.whl", hash = "sha256:27468897f628c627230dba07ec65dc8d0db566923c48f29e084ce382119802bc"}, + {file = "aiohttp-3.9.3.tar.gz", hash = "sha256:90842933e5d1ff760fae6caca4b2b3edba53ba8f4b71e95dacf2818a2aca06f7"}, ] [package.dependencies] @@ -153,13 +153,13 @@ files = [ [[package]] name = "anyio" -version = "4.2.0" +version = "4.3.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.8" files = [ - {file = "anyio-4.2.0-py3-none-any.whl", hash = "sha256:745843b39e829e108e518c489b31dc757de7d2131d53fac32bd8df268227bfee"}, - {file = "anyio-4.2.0.tar.gz", hash = "sha256:e1875bb4b4e2de1669f4bc7869b6d3f54231cdced71605e6e64c9be77e3be50f"}, + {file = "anyio-4.3.0-py3-none-any.whl", hash = "sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8"}, + {file = "anyio-4.3.0.tar.gz", hash = "sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6"}, ] [package.dependencies] @@ -186,13 +186,13 @@ files = [ [[package]] name = "appnope" -version = "0.1.3" +version = "0.1.4" description = "Disable App Nap on macOS >= 10.9" optional = false -python-versions = "*" +python-versions = ">=3.6" files = [ - {file = "appnope-0.1.3-py2.py3-none-any.whl", hash = "sha256:265a455292d0bd8a72453494fa24df5a11eb18373a60c7c0430889f22548605e"}, - {file = "appnope-0.1.3.tar.gz", hash = "sha256:02bd91c4de869fbb1e1c50aafc4098827a7a54ab2f39d9dcba6c9547ed920e24"}, 
+ {file = "appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c"}, + {file = "appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee"}, ] [[package]] @@ -381,19 +381,22 @@ test-tox-coverage = ["coverage (>=5.5)"] [[package]] name = "beautifulsoup4" -version = "4.12.2" +version = "4.12.3" description = "Screen-scraping library" optional = false python-versions = ">=3.6.0" files = [ - {file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"}, - {file = "beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da"}, + {file = "beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed"}, + {file = "beautifulsoup4-4.12.3.tar.gz", hash = "sha256:74e3d1928edc070d21748185c46e3fb33490f22f52a3addee9aee0f4f7781051"}, ] [package.dependencies] soupsieve = ">1.2" [package.extras] +cchardet = ["cchardet"] +chardet = ["chardet"] +charset-normalizer = ["charset-normalizer"] html5lib = ["html5lib"] lxml = ["lxml"] @@ -474,13 +477,13 @@ css = ["tinycss2 (>=1.1.0,<1.3)"] [[package]] name = "certifi" -version = "2023.11.17" +version = "2024.2.2" description = "Python package for providing Mozilla's CA Bundle." 
optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2023.11.17-py3-none-any.whl", hash = "sha256:e036ab49d5b79556f99cfc2d9320b34cfbe5be05c5871b51de9329f0603b0474"}, - {file = "certifi-2023.11.17.tar.gz", hash = "sha256:9b469f3a900bf28dc19b8cfbf8019bf47f7fdd1a65a1d4ffb98fc14166beb4d1"}, + {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"}, + {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"}, ] [[package]] @@ -648,35 +651,22 @@ files = [ [[package]] name = "circuitsvis" -version = "1.43.2" +version = "1.41.0" description = "Mechanistic Interpretability Visualizations" optional = false -python-versions = ">=3.8" +python-versions = ">=3.7,<4.0" files = [ - {file = "circuitsvis-1.43.2-py3-none-any.whl", hash = "sha256:1128fde5de8b738dd3c932d0b0ec4ee5556387b4405592fdf37f617e647183fb"}, - {file = "circuitsvis-1.43.2.tar.gz", hash = "sha256:388c1a6ea1bcf308da51fa6f67be761483ba361321d2e111f4c28faaea458287"}, + {file = "circuitsvis-1.41.0-py3-none-any.whl", hash = "sha256:53dc12c955c160b8108a0eb17ed14a34ba9f53b218457d29f351cba3db31acb7"}, + {file = "circuitsvis-1.41.0.tar.gz", hash = "sha256:386385f38d8b9de1bbef125fa282afc9157027bc2dcdc4c04feafbc22bc71d17"}, ] [package.dependencies] -importlib-metadata = ">=5.1.0" +importlib-metadata = ">=5.1.0,<6.0.0" numpy = [ - {version = ">=1.20,<1.25", markers = "python_version >= \"3.8\" and python_version < \"3.9\""}, - {version = ">=1.24", markers = "python_version >= \"3.9\" and python_version < \"3.12\""}, - {version = ">=1.26", markers = "python_version >= \"3.12\" and python_version < \"3.13\""}, + {version = ">=1.21,<2.0", markers = "python_version < \"3.10\""}, + {version = ">=1.23,<2.0", markers = "python_version >= \"3.10\""}, ] -nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} 
-nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nccl-cu12 = {version = "2.18.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -torch = ">=1.10" -triton = {version = "2.1.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +torch = {version = ">=1.10", markers = "python_version >= \"3.8\""} [[package]] name = "click" @@ -705,13 +695,13 @@ files = [ [[package]] name = "comm" -version = "0.2.1" +version = "0.2.2" description = "Jupyter Python Comm implementation, for usage in ipykernel, xeus-python etc." 
optional = false python-versions = ">=3.8" files = [ - {file = "comm-0.2.1-py3-none-any.whl", hash = "sha256:87928485c0dfc0e7976fd89fc1e187023cf587e7c353e4a9b417555b44adf021"}, - {file = "comm-0.2.1.tar.gz", hash = "sha256:0bc91edae1344d39d3661dcbc36937181fdaddb304790458f8b044dbc064b89a"}, + {file = "comm-0.2.2-py3-none-any.whl", hash = "sha256:e6fb86cb70ff661ee8c9c14e7d36d6de3b4066f1441be4063df9c5009f0a64d3"}, + {file = "comm-0.2.2.tar.gz", hash = "sha256:3fd7a84065306e07bea1773df6eb8282de51ba82f77c72f9c85716ab11fe980e"}, ] [package.dependencies] @@ -722,63 +712,63 @@ test = ["pytest"] [[package]] name = "coverage" -version = "7.4.0" +version = "7.4.4" description = "Code coverage measurement for Python" optional = false python-versions = ">=3.8" files = [ - {file = "coverage-7.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:36b0ea8ab20d6a7564e89cb6135920bc9188fb5f1f7152e94e8300b7b189441a"}, - {file = "coverage-7.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0676cd0ba581e514b7f726495ea75aba3eb20899d824636c6f59b0ed2f88c471"}, - {file = "coverage-7.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0ca5c71a5a1765a0f8f88022c52b6b8be740e512980362f7fdbb03725a0d6b9"}, - {file = "coverage-7.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7c97726520f784239f6c62506bc70e48d01ae71e9da128259d61ca5e9788516"}, - {file = "coverage-7.4.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:815ac2d0f3398a14286dc2cea223a6f338109f9ecf39a71160cd1628786bc6f5"}, - {file = "coverage-7.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:80b5ee39b7f0131ebec7968baa9b2309eddb35b8403d1869e08f024efd883566"}, - {file = "coverage-7.4.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5b2ccb7548a0b65974860a78c9ffe1173cfb5877460e5a229238d985565574ae"}, - {file = 
"coverage-7.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:995ea5c48c4ebfd898eacb098164b3cc826ba273b3049e4a889658548e321b43"}, - {file = "coverage-7.4.0-cp310-cp310-win32.whl", hash = "sha256:79287fd95585ed36e83182794a57a46aeae0b64ca53929d1176db56aacc83451"}, - {file = "coverage-7.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:5b14b4f8760006bfdb6e08667af7bc2d8d9bfdb648351915315ea17645347137"}, - {file = "coverage-7.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:04387a4a6ecb330c1878907ce0dc04078ea72a869263e53c72a1ba5bbdf380ca"}, - {file = "coverage-7.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ea81d8f9691bb53f4fb4db603203029643caffc82bf998ab5b59ca05560f4c06"}, - {file = "coverage-7.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74775198b702868ec2d058cb92720a3c5a9177296f75bd97317c787daf711505"}, - {file = "coverage-7.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76f03940f9973bfaee8cfba70ac991825611b9aac047e5c80d499a44079ec0bc"}, - {file = "coverage-7.4.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:485e9f897cf4856a65a57c7f6ea3dc0d4e6c076c87311d4bc003f82cfe199d25"}, - {file = "coverage-7.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6ae8c9d301207e6856865867d762a4b6fd379c714fcc0607a84b92ee63feff70"}, - {file = "coverage-7.4.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bf477c355274a72435ceb140dc42de0dc1e1e0bf6e97195be30487d8eaaf1a09"}, - {file = "coverage-7.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:83c2dda2666fe32332f8e87481eed056c8b4d163fe18ecc690b02802d36a4d26"}, - {file = "coverage-7.4.0-cp311-cp311-win32.whl", hash = "sha256:697d1317e5290a313ef0d369650cfee1a114abb6021fa239ca12b4849ebbd614"}, - {file = "coverage-7.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:26776ff6c711d9d835557ee453082025d871e30b3fd6c27fcef14733f67f0590"}, - 
{file = "coverage-7.4.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:13eaf476ec3e883fe3e5fe3707caeb88268a06284484a3daf8250259ef1ba143"}, - {file = "coverage-7.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:846f52f46e212affb5bcf131c952fb4075b55aae6b61adc9856222df89cbe3e2"}, - {file = "coverage-7.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26f66da8695719ccf90e794ed567a1549bb2644a706b41e9f6eae6816b398c4a"}, - {file = "coverage-7.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:164fdcc3246c69a6526a59b744b62e303039a81e42cfbbdc171c91a8cc2f9446"}, - {file = "coverage-7.4.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:316543f71025a6565677d84bc4df2114e9b6a615aa39fb165d697dba06a54af9"}, - {file = "coverage-7.4.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bb1de682da0b824411e00a0d4da5a784ec6496b6850fdf8c865c1d68c0e318dd"}, - {file = "coverage-7.4.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:0e8d06778e8fbffccfe96331a3946237f87b1e1d359d7fbe8b06b96c95a5407a"}, - {file = "coverage-7.4.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a56de34db7b7ff77056a37aedded01b2b98b508227d2d0979d373a9b5d353daa"}, - {file = "coverage-7.4.0-cp312-cp312-win32.whl", hash = "sha256:51456e6fa099a8d9d91497202d9563a320513fcf59f33991b0661a4a6f2ad450"}, - {file = "coverage-7.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:cd3c1e4cb2ff0083758f09be0f77402e1bdf704adb7f89108007300a6da587d0"}, - {file = "coverage-7.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e9d1bf53c4c8de58d22e0e956a79a5b37f754ed1ffdbf1a260d9dcfa2d8a325e"}, - {file = "coverage-7.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:109f5985182b6b81fe33323ab4707011875198c41964f014579cf82cebf2bb85"}, - {file = "coverage-7.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:3cc9d4bc55de8003663ec94c2f215d12d42ceea128da8f0f4036235a119c88ac"}, - {file = "coverage-7.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cc6d65b21c219ec2072c1293c505cf36e4e913a3f936d80028993dd73c7906b1"}, - {file = "coverage-7.4.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a10a4920def78bbfff4eff8a05c51be03e42f1c3735be42d851f199144897ba"}, - {file = "coverage-7.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b8e99f06160602bc64da35158bb76c73522a4010f0649be44a4e167ff8555952"}, - {file = "coverage-7.4.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:7d360587e64d006402b7116623cebf9d48893329ef035278969fa3bbf75b697e"}, - {file = "coverage-7.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:29f3abe810930311c0b5d1a7140f6395369c3db1be68345638c33eec07535105"}, - {file = "coverage-7.4.0-cp38-cp38-win32.whl", hash = "sha256:5040148f4ec43644702e7b16ca864c5314ccb8ee0751ef617d49aa0e2d6bf4f2"}, - {file = "coverage-7.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:9864463c1c2f9cb3b5db2cf1ff475eed2f0b4285c2aaf4d357b69959941aa555"}, - {file = "coverage-7.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:936d38794044b26c99d3dd004d8af0035ac535b92090f7f2bb5aa9c8e2f5cd42"}, - {file = "coverage-7.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:799c8f873794a08cdf216aa5d0531c6a3747793b70c53f70e98259720a6fe2d7"}, - {file = "coverage-7.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e7defbb9737274023e2d7af02cac77043c86ce88a907c58f42b580a97d5bcca9"}, - {file = "coverage-7.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a1526d265743fb49363974b7aa8d5899ff64ee07df47dd8d3e37dcc0818f09ed"}, - {file = "coverage-7.4.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:bf635a52fc1ea401baf88843ae8708591aa4adff875e5c23220de43b1ccf575c"}, - {file = "coverage-7.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:756ded44f47f330666843b5781be126ab57bb57c22adbb07d83f6b519783b870"}, - {file = "coverage-7.4.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:0eb3c2f32dabe3a4aaf6441dde94f35687224dfd7eb2a7f47f3fd9428e421058"}, - {file = "coverage-7.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bfd5db349d15c08311702611f3dccbef4b4e2ec148fcc636cf8739519b4a5c0f"}, - {file = "coverage-7.4.0-cp39-cp39-win32.whl", hash = "sha256:53d7d9158ee03956e0eadac38dfa1ec8068431ef8058fe6447043db1fb40d932"}, - {file = "coverage-7.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:cfd2a8b6b0d8e66e944d47cdec2f47c48fef2ba2f2dff5a9a75757f64172857e"}, - {file = "coverage-7.4.0-pp38.pp39.pp310-none-any.whl", hash = "sha256:c530833afc4707fe48524a44844493f36d8727f04dcce91fb978c414a8556cc6"}, - {file = "coverage-7.4.0.tar.gz", hash = "sha256:707c0f58cb1712b8809ece32b68996ee1e609f71bd14615bd8f87a1293cb610e"}, + {file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"}, + {file = "coverage-7.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ccd341521be3d1b3daeb41960ae94a5e87abe2f46f17224ba5d6f2b8398016cf"}, + {file = "coverage-7.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fa497a8ab37784fbb20ab699c246053ac294d13fc7eb40ec007a5043ec91f8"}, + {file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1a93009cb80730c9bca5d6d4665494b725b6e8e157c1cb7f2db5b4b122ea562"}, + {file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:690db6517f09336559dc0b5f55342df62370a48f5469fabf502db2c6d1cffcd2"}, + {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = 
"sha256:09c3255458533cb76ef55da8cc49ffab9e33f083739c8bd4f58e79fecfe288f7"}, + {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8ce1415194b4a6bd0cdcc3a1dfbf58b63f910dcb7330fe15bdff542c56949f87"}, + {file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b91cbc4b195444e7e258ba27ac33769c41b94967919f10037e6355e998af255c"}, + {file = "coverage-7.4.4-cp310-cp310-win32.whl", hash = "sha256:598825b51b81c808cb6f078dcb972f96af96b078faa47af7dfcdf282835baa8d"}, + {file = "coverage-7.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:09ef9199ed6653989ebbcaacc9b62b514bb63ea2f90256e71fea3ed74bd8ff6f"}, + {file = "coverage-7.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f9f50e7ef2a71e2fae92774c99170eb8304e3fdf9c8c3c7ae9bab3e7229c5cf"}, + {file = "coverage-7.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:623512f8ba53c422fcfb2ce68362c97945095b864cda94a92edbaf5994201083"}, + {file = "coverage-7.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0513b9508b93da4e1716744ef6ebc507aff016ba115ffe8ecff744d1322a7b63"}, + {file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40209e141059b9370a2657c9b15607815359ab3ef9918f0196b6fccce8d3230f"}, + {file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a2b2b78c78293782fd3767d53e6474582f62443d0504b1554370bde86cc8227"}, + {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:73bfb9c09951125d06ee473bed216e2c3742f530fc5acc1383883125de76d9cd"}, + {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f384c3cc76aeedce208643697fb3e8437604b512255de6d18dae3f27655a384"}, + {file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:54eb8d1bf7cacfbf2a3186019bcf01d11c666bd495ed18717162f7eb1e9dd00b"}, + {file = 
"coverage-7.4.4-cp311-cp311-win32.whl", hash = "sha256:cac99918c7bba15302a2d81f0312c08054a3359eaa1929c7e4b26ebe41e9b286"}, + {file = "coverage-7.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:b14706df8b2de49869ae03a5ccbc211f4041750cd4a66f698df89d44f4bd30ec"}, + {file = "coverage-7.4.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:201bef2eea65e0e9c56343115ba3814e896afe6d36ffd37bab783261db430f76"}, + {file = "coverage-7.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41c9c5f3de16b903b610d09650e5e27adbfa7f500302718c9ffd1c12cf9d6818"}, + {file = "coverage-7.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d898fe162d26929b5960e4e138651f7427048e72c853607f2b200909794ed978"}, + {file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ea79bb50e805cd6ac058dfa3b5c8f6c040cb87fe83de10845857f5535d1db70"}, + {file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce4b94265ca988c3f8e479e741693d143026632672e3ff924f25fab50518dd51"}, + {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00838a35b882694afda09f85e469c96367daa3f3f2b097d846a7216993d37f4c"}, + {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fdfafb32984684eb03c2d83e1e51f64f0906b11e64482df3c5db936ce3839d48"}, + {file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:69eb372f7e2ece89f14751fbcbe470295d73ed41ecd37ca36ed2eb47512a6ab9"}, + {file = "coverage-7.4.4-cp312-cp312-win32.whl", hash = "sha256:137eb07173141545e07403cca94ab625cc1cc6bc4c1e97b6e3846270e7e1fea0"}, + {file = "coverage-7.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:d71eec7d83298f1af3326ce0ff1d0ea83c7cb98f72b577097f9083b20bdaf05e"}, + {file = "coverage-7.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ae728ff3b5401cc320d792866987e7e7e880e6ebd24433b70a33b643bb0384"}, + {file = 
"coverage-7.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc4f1358cb0c78edef3ed237ef2c86056206bb8d9140e73b6b89fbcfcbdd40e1"}, + {file = "coverage-7.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8130a2aa2acb8788e0b56938786c33c7c98562697bf9f4c7d6e8e5e3a0501e4a"}, + {file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf271892d13e43bc2b51e6908ec9a6a5094a4df1d8af0bfc360088ee6c684409"}, + {file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4cdc86d54b5da0df6d3d3a2f0b710949286094c3a6700c21e9015932b81447e"}, + {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ae71e7ddb7a413dd60052e90528f2f65270aad4b509563af6d03d53e979feafd"}, + {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:38dd60d7bf242c4ed5b38e094baf6401faa114fc09e9e6632374388a404f98e7"}, + {file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa5b1c1bfc28384f1f53b69a023d789f72b2e0ab1b3787aae16992a7ca21056c"}, + {file = "coverage-7.4.4-cp38-cp38-win32.whl", hash = "sha256:dfa8fe35a0bb90382837b238fff375de15f0dcdb9ae68ff85f7a63649c98527e"}, + {file = "coverage-7.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:b2991665420a803495e0b90a79233c1433d6ed77ef282e8e152a324bbbc5e0c8"}, + {file = "coverage-7.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b799445b9f7ee8bf299cfaed6f5b226c0037b74886a4e11515e569b36fe310d"}, + {file = "coverage-7.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b4d33f418f46362995f1e9d4f3a35a1b6322cb959c31d88ae56b0298e1c22357"}, + {file = "coverage-7.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aadacf9a2f407a4688d700e4ebab33a7e2e408f2ca04dbf4aef17585389eff3e"}, + {file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:7c95949560050d04d46b919301826525597f07b33beba6187d04fa64d47ac82e"}, + {file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff7687ca3d7028d8a5f0ebae95a6e4827c5616b31a4ee1192bdfde697db110d4"}, + {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5fc1de20b2d4a061b3df27ab9b7c7111e9a710f10dc2b84d33a4ab25065994ec"}, + {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c74880fc64d4958159fbd537a091d2a585448a8f8508bf248d72112723974cbd"}, + {file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:742a76a12aa45b44d236815d282b03cfb1de3b4323f3e4ec933acfae08e54ade"}, + {file = "coverage-7.4.4-cp39-cp39-win32.whl", hash = "sha256:d89d7b2974cae412400e88f35d86af72208e1ede1a541954af5d944a8ba46c57"}, + {file = "coverage-7.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:9ca28a302acb19b6af89e90f33ee3e1906961f94b54ea37de6737b7ca9d8827c"}, + {file = "coverage-7.4.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:b2c5edc4ac10a7ef6605a966c58929ec6c1bd0917fb8c15cb3363f65aa40e677"}, + {file = "coverage-7.4.4.tar.gz", hash = "sha256:c901df83d097649e257e803be22592aedfd5182f07b3cc87d640bbb9afd50f49"}, ] [package.dependencies] @@ -789,71 +779,77 @@ toml = ["tomli"] [[package]] name = "datasets" -version = "2.14.4" +version = "2.18.0" description = "HuggingFace community-driven open-source library of datasets" optional = false python-versions = ">=3.8.0" files = [ - {file = "datasets-2.14.4-py3-none-any.whl", hash = "sha256:29336bd316a7d827ccd4da2236596279b20ca2ac78f64c04c9483da7cbc2459b"}, - {file = "datasets-2.14.4.tar.gz", hash = "sha256:ef29c2b5841de488cd343cfc26ab979bff77efa4d2285af51f1ad7db5c46a83b"}, + {file = "datasets-2.18.0-py3-none-any.whl", hash = "sha256:f1bbf0e2896917a914de01cbd37075b14deea3837af87ad0d9f697388ccaeb50"}, + {file = "datasets-2.18.0.tar.gz", hash = 
"sha256:cdf8b8c6abf7316377ba4f49f9589a4c74556d6b481afd0abd2284f3d69185cb"}, ] [package.dependencies] aiohttp = "*" -dill = ">=0.3.0,<0.3.8" -fsspec = {version = ">=2021.11.1", extras = ["http"]} -huggingface-hub = ">=0.14.0,<1.0.0" +dill = ">=0.3.0,<0.3.9" +filelock = "*" +fsspec = {version = ">=2023.1.0,<=2024.2.0", extras = ["http"]} +huggingface-hub = ">=0.19.4" multiprocess = "*" numpy = ">=1.17" packaging = "*" pandas = "*" -pyarrow = ">=8.0.0" +pyarrow = ">=12.0.0" +pyarrow-hotfix = "*" pyyaml = ">=5.1" requests = ">=2.19.0" tqdm = ">=4.62.1" xxhash = "*" [package.extras] -apache-beam = ["apache-beam (>=2.26.0,<2.44.0)"] +apache-beam = ["apache-beam (>=2.26.0)"] audio = ["librosa", "soundfile (>=0.12.1)"] benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"] -dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "black (>=23.1,<24.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "rarfile (>=4.0)", "ruff (>=0.0.241)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] +dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] docs = ["s3fs", "tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", 
"tensorflow-macos", "torch", "transformers"] -jax = ["jax (>=0.2.8,!=0.3.2,<=0.3.25)", "jaxlib (>=0.1.65,<=0.3.25)"] +jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"] metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"] -quality = ["black (>=23.1,<24.0)", "pyyaml (>=5.3.1)", "ruff (>=0.0.241)"] +quality = ["ruff (>=0.3.0)"] s3 = ["s3fs"] tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)", "tensorflow-macos"] tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"] -tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy (<2.0.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch", "transformers", "zstandard"] +tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "tensorflow-macos", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"] torch = ["torch"] vision = ["Pillow (>=6.2.1)"] [[package]] name = "debugpy" -version = "1.8.0" +version = "1.8.1" description = "An implementation of the Debug Adapter Protocol for Python" optional = false python-versions = ">=3.8" files = 
[ - {file = "debugpy-1.8.0-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:7fb95ca78f7ac43393cd0e0f2b6deda438ec7c5e47fa5d38553340897d2fbdfb"}, - {file = "debugpy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef9ab7df0b9a42ed9c878afd3eaaff471fce3fa73df96022e1f5c9f8f8c87ada"}, - {file = "debugpy-1.8.0-cp310-cp310-win32.whl", hash = "sha256:a8b7a2fd27cd9f3553ac112f356ad4ca93338feadd8910277aff71ab24d8775f"}, - {file = "debugpy-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:5d9de202f5d42e62f932507ee8b21e30d49aae7e46d5b1dd5c908db1d7068637"}, - {file = "debugpy-1.8.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:ef54404365fae8d45cf450d0544ee40cefbcb9cb85ea7afe89a963c27028261e"}, - {file = "debugpy-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60009b132c91951354f54363f8ebdf7457aeb150e84abba5ae251b8e9f29a8a6"}, - {file = "debugpy-1.8.0-cp311-cp311-win32.whl", hash = "sha256:8cd0197141eb9e8a4566794550cfdcdb8b3db0818bdf8c49a8e8f8053e56e38b"}, - {file = "debugpy-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:a64093656c4c64dc6a438e11d59369875d200bd5abb8f9b26c1f5f723622e153"}, - {file = "debugpy-1.8.0-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:b05a6b503ed520ad58c8dc682749113d2fd9f41ffd45daec16e558ca884008cd"}, - {file = "debugpy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c6fb41c98ec51dd010d7ed650accfd07a87fe5e93eca9d5f584d0578f28f35f"}, - {file = "debugpy-1.8.0-cp38-cp38-win32.whl", hash = "sha256:46ab6780159eeabb43c1495d9c84cf85d62975e48b6ec21ee10c95767c0590aa"}, - {file = "debugpy-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:bdc5ef99d14b9c0fcb35351b4fbfc06ac0ee576aeab6b2511702e5a648a2e595"}, - {file = "debugpy-1.8.0-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:61eab4a4c8b6125d41a34bad4e5fe3d2cc145caecd63c3fe953be4cc53e65bf8"}, - {file = "debugpy-1.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:125b9a637e013f9faac0a3d6a82bd17c8b5d2c875fb6b7e2772c5aba6d082332"}, - {file = "debugpy-1.8.0-cp39-cp39-win32.whl", hash = "sha256:57161629133113c97b387382045649a2b985a348f0c9366e22217c87b68b73c6"}, - {file = "debugpy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:e3412f9faa9ade82aa64a50b602544efcba848c91384e9f93497a458767e6926"}, - {file = "debugpy-1.8.0-py2.py3-none-any.whl", hash = "sha256:9c9b0ac1ce2a42888199df1a1906e45e6f3c9555497643a85e0bf2406e3ffbc4"}, - {file = "debugpy-1.8.0.zip", hash = "sha256:12af2c55b419521e33d5fb21bd022df0b5eb267c3e178f1d374a63a2a6bdccd0"}, + {file = "debugpy-1.8.1-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:3bda0f1e943d386cc7a0e71bfa59f4137909e2ed947fb3946c506e113000f741"}, + {file = "debugpy-1.8.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dda73bf69ea479c8577a0448f8c707691152e6c4de7f0c4dec5a4bc11dee516e"}, + {file = "debugpy-1.8.1-cp310-cp310-win32.whl", hash = "sha256:3a79c6f62adef994b2dbe9fc2cc9cc3864a23575b6e387339ab739873bea53d0"}, + {file = "debugpy-1.8.1-cp310-cp310-win_amd64.whl", hash = "sha256:7eb7bd2b56ea3bedb009616d9e2f64aab8fc7000d481faec3cd26c98a964bcdd"}, + {file = "debugpy-1.8.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:016a9fcfc2c6b57f939673c874310d8581d51a0fe0858e7fac4e240c5eb743cb"}, + {file = "debugpy-1.8.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd97ed11a4c7f6d042d320ce03d83b20c3fb40da892f994bc041bbc415d7a099"}, + {file = "debugpy-1.8.1-cp311-cp311-win32.whl", hash = "sha256:0de56aba8249c28a300bdb0672a9b94785074eb82eb672db66c8144fff673146"}, + {file = "debugpy-1.8.1-cp311-cp311-win_amd64.whl", hash = "sha256:1a9fe0829c2b854757b4fd0a338d93bc17249a3bf69ecf765c61d4c522bb92a8"}, + {file = "debugpy-1.8.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3ebb70ba1a6524d19fa7bb122f44b74170c447d5746a503e36adc244a20ac539"}, + {file = 
"debugpy-1.8.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2e658a9630f27534e63922ebf655a6ab60c370f4d2fc5c02a5b19baf4410ace"}, + {file = "debugpy-1.8.1-cp312-cp312-win32.whl", hash = "sha256:caad2846e21188797a1f17fc09c31b84c7c3c23baf2516fed5b40b378515bbf0"}, + {file = "debugpy-1.8.1-cp312-cp312-win_amd64.whl", hash = "sha256:edcc9f58ec0fd121a25bc950d4578df47428d72e1a0d66c07403b04eb93bcf98"}, + {file = "debugpy-1.8.1-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:7a3afa222f6fd3d9dfecd52729bc2e12c93e22a7491405a0ecbf9e1d32d45b39"}, + {file = "debugpy-1.8.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d915a18f0597ef685e88bb35e5d7ab968964b7befefe1aaea1eb5b2640b586c7"}, + {file = "debugpy-1.8.1-cp38-cp38-win32.whl", hash = "sha256:92116039b5500633cc8d44ecc187abe2dfa9b90f7a82bbf81d079fcdd506bae9"}, + {file = "debugpy-1.8.1-cp38-cp38-win_amd64.whl", hash = "sha256:e38beb7992b5afd9d5244e96ad5fa9135e94993b0c551ceebf3fe1a5d9beb234"}, + {file = "debugpy-1.8.1-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:bfb20cb57486c8e4793d41996652e5a6a885b4d9175dd369045dad59eaacea42"}, + {file = "debugpy-1.8.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efd3fdd3f67a7e576dd869c184c5dd71d9aaa36ded271939da352880c012e703"}, + {file = "debugpy-1.8.1-cp39-cp39-win32.whl", hash = "sha256:58911e8521ca0c785ac7a0539f1e77e0ce2df753f786188f382229278b4cdf23"}, + {file = "debugpy-1.8.1-cp39-cp39-win_amd64.whl", hash = "sha256:6df9aa9599eb05ca179fb0b810282255202a66835c6efb1d112d21ecb830ddd3"}, + {file = "debugpy-1.8.1-py2.py3-none-any.whl", hash = "sha256:28acbe2241222b87e255260c76741e1fbf04fdc3b6d094fcf57b6c6f75ce1242"}, + {file = "debugpy-1.8.1.zip", hash = "sha256:f696d6be15be87aef621917585f9bb94b1dc9e8aced570db1b8a6fc14e8f9b42"}, ] [[package]] @@ -880,17 +876,18 @@ files = [ [[package]] name = "dill" -version = "0.3.7" +version = "0.3.8" description = "serialize all 
of Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"}, - {file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"}, + {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, + {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, ] [package.extras] graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] [[package]] name = "docker-pycreds" @@ -983,18 +980,18 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc [[package]] name = "filelock" -version = "3.13.1" +version = "3.13.3" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, - {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"}, + {file = "filelock-3.13.3-py3-none-any.whl", hash = "sha256:5ffa845303983e7a0b7ae17636509bc97997d58afeafa72fb141a17b152284cb"}, + {file = "filelock-3.13.3.tar.gz", hash = "sha256:a79895a25bbefdf55d1a2a0a80968f7dbb28edcd6d4234a0afb3f37ecde4b546"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout 
(>=2.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] @@ -1096,18 +1093,17 @@ files = [ [[package]] name = "fsspec" -version = "2023.12.2" +version = "2024.2.0" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2023.12.2-py3-none-any.whl", hash = "sha256:d800d87f72189a745fa3d6b033b9dc4a34ad069f60ca60b943a63599f5501960"}, - {file = "fsspec-2023.12.2.tar.gz", hash = "sha256:8548d39e8810b59c38014934f6b31e57f40c1b20f911f4cc2b85389c7e9bf0cb"}, + {file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"}, + {file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"}, ] [package.dependencies] aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""} -requests = {version = "*", optional = true, markers = "extra == \"http\""} [package.extras] abfs = ["adlfs"] @@ -1124,7 +1120,7 @@ github = ["requests"] gs = ["gcsfs"] gui = ["panel"] hdfs = ["pyarrow (>=1)"] -http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "requests"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"] libarchive = ["libarchive-c"] oci = ["ocifs"] s3 = ["s3fs"] @@ -1166,30 +1162,86 @@ smmap = ">=3.0.1,<6" [[package]] name = "gitpython" -version = "3.1.40" +version = "3.1.42" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" files = [ - {file = "GitPython-3.1.40-py3-none-any.whl", hash = "sha256:cf14627d5a8049ffbf49915732e5eddbe8134c3bdb9d476e6182b676fc573f8a"}, - {file = "GitPython-3.1.40.tar.gz", hash = "sha256:22b126e9ffb671fdd0c129796343a02bf67bf2994b35449ffc9321aa755e18a4"}, + {file = "GitPython-3.1.42-py3-none-any.whl", hash = "sha256:1bf9cd7c9e7255f77778ea54359e54ac22a72a5b51288c457c881057b7bb9ecd"}, + {file = "GitPython-3.1.42.tar.gz", hash = 
"sha256:2d99869e0fef71a73cbd242528105af1d6c1b108c60dfabd994bf292f76c3ceb"}, ] [package.dependencies] gitdb = ">=4.0.1,<5" [package.extras] -test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-instafail", "pytest-subtests", "pytest-sugar"] +test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar"] + +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + +[[package]] +name = "httpcore" +version = "1.0.5" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, + {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.26.0)"] + +[[package]] +name = "httpx" +version = "0.27.0" +description = "The next generation HTTP client." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, + {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] [[package]] name = "huggingface-hub" -version = "0.20.2" +version = "0.22.1" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.20.2-py3-none-any.whl", hash = "sha256:53752eda2239d30a470c307a61cf9adcf136bc77b0a734338c7d04941af560d8"}, - {file = "huggingface_hub-0.20.2.tar.gz", hash = "sha256:215c5fceff631030c7a3d19ba7b588921c908b3f21eef31d160ebc245b200ff6"}, + {file = "huggingface_hub-0.22.1-py3-none-any.whl", hash = "sha256:eac63947923d15c9a68681d7ed2d9599e058860617064e3ee6bd91a4b954faaf"}, + {file = "huggingface_hub-0.22.1.tar.gz", hash = "sha256:5b8aaee5f3618cd432f49886da9935bbe8fab92d719011826430907b93171dd8"}, ] [package.dependencies] @@ -1202,15 +1254,17 @@ tqdm = ">=4.42.1" typing-extensions = ">=3.7.4.3" [package.extras] -all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "minijinja 
(>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] cli = ["InquirerPy (==0.3.4)"] -dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] -inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"] -quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"] +hf-transfer = ["hf-transfer (>=0.1.4)"] +inference = ["aiohttp", "minijinja (>=1.0)"] +quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"] tensorflow = ["graphviz", "pydot", "tensorflow"] -testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] -torch = ["torch"] +tensorflow-testing = ["keras (<3.0)", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", 
"aiohttp", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +torch = ["safetensors", "torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] [[package]] @@ -1237,32 +1291,32 @@ files = [ [[package]] name = "importlib-metadata" -version = "7.0.1" +version = "5.2.0" description = "Read metadata from Python packages" optional = false -python-versions = ">=3.8" +python-versions = ">=3.7" files = [ - {file = "importlib_metadata-7.0.1-py3-none-any.whl", hash = "sha256:4805911c3a4ec7c3966410053e9ec6a1fecd629117df5adee56dfc9432a1081e"}, - {file = "importlib_metadata-7.0.1.tar.gz", hash = "sha256:f238736bb06590ae52ac1fab06a3a9ef1d8dce2b7a35b5ab329371d6c8f5d2cc"}, + {file = "importlib_metadata-5.2.0-py3-none-any.whl", hash = "sha256:0eafa39ba42bf225fc00e67f701d71f85aead9f878569caf13c3724f704b970f"}, + {file = "importlib_metadata-5.2.0.tar.gz", hash = "sha256:404d48d62bba0b7a77ff9d405efd91501bef2e67ff4ace0bed40a0cf28c3c7cd"}, ] [package.dependencies] zipp = ">=0.5" [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] +testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", 
"pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"] [[package]] name = "importlib-resources" -version = "6.1.1" +version = "6.4.0" description = "Read resources from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_resources-6.1.1-py3-none-any.whl", hash = "sha256:e8bf90d8213b486f428c9c39714b920041cb02c184686a3dee24905aaa8105d6"}, - {file = "importlib_resources-6.1.1.tar.gz", hash = "sha256:3893a00122eafde6894c59914446a512f728a0c1a45f9bb9b63721b6bacf0b4a"}, + {file = "importlib_resources-6.4.0-py3-none-any.whl", hash = "sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c"}, + {file = "importlib_resources-6.4.0.tar.gz", hash = "sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145"}, ] [package.dependencies] @@ -1270,7 +1324,7 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff", "zipp (>=3.17)"] +testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] [[package]] name = "iniconfig" @@ -1285,13 +1339,13 @@ files = [ [[package]] name = "ipykernel" -version = "6.28.0" +version = "6.29.4" description = "IPython Kernel for Jupyter" optional = false python-versions = ">=3.8" files = [ - {file = "ipykernel-6.28.0-py3-none-any.whl", hash = "sha256:c6e9a9c63a7f4095c0a22a79f765f079f9ec7be4f2430a898ddea889e8665661"}, - {file = "ipykernel-6.28.0.tar.gz", hash = "sha256:69c11403d26de69df02225916f916b37ea4b9af417da0a8c827f84328d88e5f3"}, + {file = "ipykernel-6.29.4-py3-none-any.whl", hash = 
"sha256:1181e653d95c6808039c509ef8e67c4126b3b3af7781496c7cbfb5ed938a27da"}, + {file = "ipykernel-6.29.4.tar.gz", hash = "sha256:3d44070060f9475ac2092b760123fadf105d2e2493c24848b6691a7c4f42af5c"}, ] [package.dependencies] @@ -1314,7 +1368,7 @@ cov = ["coverage[toml]", "curio", "matplotlib", "pytest-cov", "trio"] docs = ["myst-parser", "pydata-sphinx-theme", "sphinx", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "trio"] pyqt5 = ["pyqt5"] pyside6 = ["pyside6"] -test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov", "pytest-timeout"] +test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>=0.23.5)", "pytest-cov", "pytest-timeout"] [[package]] name = "ipython" @@ -1357,21 +1411,21 @@ test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pa [[package]] name = "ipywidgets" -version = "8.1.1" +version = "8.1.2" description = "Jupyter interactive widgets" optional = false python-versions = ">=3.7" files = [ - {file = "ipywidgets-8.1.1-py3-none-any.whl", hash = "sha256:2b88d728656aea3bbfd05d32c747cfd0078f9d7e159cf982433b58ad717eed7f"}, - {file = "ipywidgets-8.1.1.tar.gz", hash = "sha256:40211efb556adec6fa450ccc2a77d59ca44a060f4f9f136833df59c9f538e6e8"}, + {file = "ipywidgets-8.1.2-py3-none-any.whl", hash = "sha256:bbe43850d79fb5e906b14801d6c01402857996864d1e5b6fa62dd2ee35559f60"}, + {file = "ipywidgets-8.1.2.tar.gz", hash = "sha256:d0b9b41e49bae926a866e613a39b0f0097745d2b9f1f3dd406641b4a57ec42c9"}, ] [package.dependencies] comm = ">=0.1.3" ipython = ">=6.1.0" -jupyterlab-widgets = ">=3.0.9,<3.1.0" +jupyterlab-widgets = ">=3.0.10,<3.1.0" traitlets = ">=4.3.1" -widgetsnbextension = ">=4.0.9,<4.1.0" +widgetsnbextension = ">=4.0.10,<4.1.0" [package.extras] test = ["ipykernel", "jsonschema", "pytest (>=3.6.0)", "pytest-cov", "pytz"] @@ -1443,13 +1497,13 @@ testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] [[package]] name = "jinja2" 
-version = "3.1.2" +version = "3.1.3" description = "A very fast and expressive template engine." optional = false python-versions = ">=3.7" files = [ - {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"}, - {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"}, + {file = "Jinja2-3.1.3-py3-none-any.whl", hash = "sha256:7d6d50dd97d52cbc355597bd845fabfbac3f551e1f99619e39a35ce8c370b5fa"}, + {file = "Jinja2-3.1.3.tar.gz", hash = "sha256:ac8bd6544d4bb2c9792bf3a159e80bba8fda7f07e81bc3aed565432d5925ba90"}, ] [package.dependencies] @@ -1460,18 +1514,15 @@ i18n = ["Babel (>=2.7)"] [[package]] name = "json5" -version = "0.9.14" +version = "0.9.24" description = "A Python implementation of the JSON5 data format." optional = false -python-versions = "*" +python-versions = ">=3.8" files = [ - {file = "json5-0.9.14-py2.py3-none-any.whl", hash = "sha256:740c7f1b9e584a468dbb2939d8d458db3427f2c93ae2139d05f47e453eae964f"}, - {file = "json5-0.9.14.tar.gz", hash = "sha256:9ed66c3a6ca3510a976a9ef9b8c0787de24802724ab1860bc0153c7fdd589b02"}, + {file = "json5-0.9.24-py3-none-any.whl", hash = "sha256:4ca101fd5c7cb47960c055ef8f4d0e31e15a7c6c48c3b6f1473fc83b6c462a13"}, + {file = "json5-0.9.24.tar.gz", hash = "sha256:0c638399421da959a20952782800e5c1a78c14e08e1dc9738fa10d8ec14d58c8"}, ] -[package.extras] -dev = ["hypothesis"] - [[package]] name = "jsonpointer" version = "2.4" @@ -1485,13 +1536,13 @@ files = [ [[package]] name = "jsonschema" -version = "4.20.0" +version = "4.21.1" description = "An implementation of JSON Schema validation for Python" optional = false python-versions = ">=3.8" files = [ - {file = "jsonschema-4.20.0-py3-none-any.whl", hash = "sha256:ed6231f0429ecf966f5bc8dfef245998220549cbbcf140f913b7464c52c3b6b3"}, - {file = "jsonschema-4.20.0.tar.gz", hash = "sha256:4f614fd46d8d61258610998997743ec5492a648b33cf478c1ddc23ed4598a5fa"}, + {file = 
"jsonschema-4.21.1-py3-none-any.whl", hash = "sha256:7996507afae316306f9e2290407761157c6f78002dcf7419acb99822143d1c6f"}, + {file = "jsonschema-4.21.1.tar.gz", hash = "sha256:85727c00279f5fa6bedbe6238d2aa6403bedd8b4864ab11207d07df3cc1b2ee5"}, ] [package.dependencies] @@ -1551,13 +1602,13 @@ qtconsole = "*" [[package]] name = "jupyter-client" -version = "8.6.0" +version = "8.6.1" description = "Jupyter protocol implementation and client libraries" optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_client-8.6.0-py3-none-any.whl", hash = "sha256:909c474dbe62582ae62b758bca86d6518c85234bdee2d908c778db6d72f39d99"}, - {file = "jupyter_client-8.6.0.tar.gz", hash = "sha256:0642244bb83b4764ae60d07e010e15f0e2d275ec4e918a8f7b80fbbef3ca60c7"}, + {file = "jupyter_client-8.6.1-py3-none-any.whl", hash = "sha256:3b7bd22f058434e3b9a7ea4b1500ed47de2713872288c0d511d19926f99b459f"}, + {file = "jupyter_client-8.6.1.tar.gz", hash = "sha256:e842515e2bab8e19186d89fdfea7abd15e39dd581f94e399f00e2af5a1652d3f"}, ] [package.dependencies] @@ -1598,13 +1649,13 @@ test = ["flaky", "pexpect", "pytest"] [[package]] name = "jupyter-core" -version = "5.7.1" +version = "5.7.2" description = "Jupyter core package. A base package on which Jupyter projects rely." 
optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_core-5.7.1-py3-none-any.whl", hash = "sha256:c65c82126453a723a2804aa52409930434598fd9d35091d63dfb919d2b765bb7"}, - {file = "jupyter_core-5.7.1.tar.gz", hash = "sha256:de61a9d7fc71240f688b2fb5ab659fbb56979458dc66a71decd098e03c79e218"}, + {file = "jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409"}, + {file = "jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9"}, ] [package.dependencies] @@ -1614,17 +1665,17 @@ traitlets = ">=5.3" [package.extras] docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] -test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] +test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout"] [[package]] name = "jupyter-events" -version = "0.9.0" +version = "0.10.0" description = "Jupyter Event System library" optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_events-0.9.0-py3-none-any.whl", hash = "sha256:d853b3c10273ff9bc8bb8b30076d65e2c9685579db736873de6c2232dde148bf"}, - {file = "jupyter_events-0.9.0.tar.gz", hash = "sha256:81ad2e4bc710881ec274d31c6c50669d71bbaa5dd9d01e600b56faa85700d399"}, + {file = "jupyter_events-0.10.0-py3-none-any.whl", hash = "sha256:4b72130875e59d57716d327ea70d3ebc3af1944d3717e5a498b8a06c6c159960"}, + {file = "jupyter_events-0.10.0.tar.gz", hash = "sha256:670b8229d3cc882ec782144ed22e0d29e1c2d639263f92ca8383e66682845e22"}, ] [package.dependencies] @@ -1643,13 +1694,13 @@ test = ["click", "pre-commit", "pytest (>=7.0)", "pytest-asyncio (>=0.19.0)", "p [[package]] name = "jupyter-lsp" -version = "2.2.1" +version = "2.2.4" description = "Multi-Language Server WebSocket proxy for Jupyter Notebook/Lab server" optional = false python-versions = ">=3.8" files = [ - {file = 
"jupyter-lsp-2.2.1.tar.gz", hash = "sha256:b17fab6d70fe83c8896b0cff59237640038247c196056b43684a0902b6a9e0fb"}, - {file = "jupyter_lsp-2.2.1-py3-none-any.whl", hash = "sha256:17a689910c5e4ae5e7d334b02f31d08ffbe98108f6f658fb05e4304b4345368b"}, + {file = "jupyter-lsp-2.2.4.tar.gz", hash = "sha256:5e50033149344065348e688608f3c6d654ef06d9856b67655bd7b6bac9ee2d59"}, + {file = "jupyter_lsp-2.2.4-py3-none-any.whl", hash = "sha256:da61cb63a16b6dff5eac55c2699cc36eac975645adee02c41bdfc03bf4802e77"}, ] [package.dependencies] @@ -1658,13 +1709,13 @@ jupyter-server = ">=1.1.2" [[package]] name = "jupyter-server" -version = "2.12.3" +version = "2.13.0" description = "The backend—i.e. core services, APIs, and REST endpoints—to Jupyter web applications." optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_server-2.12.3-py3-none-any.whl", hash = "sha256:6f85310ea5e6068568a521f079fba99d8d17e4884dd1d602ab0f43b3115204a8"}, - {file = "jupyter_server-2.12.3.tar.gz", hash = "sha256:a1d2d51e497b1a6256c48b6940b0dd49b2553981baf1690077c37792f1fa23a1"}, + {file = "jupyter_server-2.13.0-py3-none-any.whl", hash = "sha256:77b2b49c3831fbbfbdb5048cef4350d12946191f833a24e5f83e5f8f4803e97b"}, + {file = "jupyter_server-2.13.0.tar.gz", hash = "sha256:c80bfb049ea20053c3d9641c2add4848b38073bf79f1729cea1faed32fc1c78e"}, ] [package.dependencies] @@ -1690,17 +1741,17 @@ websocket-client = "*" [package.extras] docs = ["ipykernel", "jinja2", "jupyter-client", "jupyter-server", "myst-parser", "nbformat", "prometheus-client", "pydata-sphinx-theme", "send2trash", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-openapi (>=0.8.0)", "sphinxcontrib-spelling", "sphinxemoji", "tornado", "typing-extensions"] -test = ["flaky", "ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-scripts", "pytest-jupyter[server] (>=0.4)", "pytest-timeout", "requests"] +test = ["flaky", "ipykernel", "pre-commit", "pytest (>=7.0)", "pytest-console-scripts", "pytest-jupyter[server] 
(>=0.7)", "pytest-timeout", "requests"] [[package]] name = "jupyter-server-terminals" -version = "0.5.1" +version = "0.5.3" description = "A Jupyter Server Extension Providing Terminals." optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_server_terminals-0.5.1-py3-none-any.whl", hash = "sha256:5e63e947ddd97bb2832db5ef837a258d9ccd4192cd608c1270850ad947ae5dd7"}, - {file = "jupyter_server_terminals-0.5.1.tar.gz", hash = "sha256:16d3be9cf48be6a1f943f3a6c93c033be259cf4779184c66421709cf63dccfea"}, + {file = "jupyter_server_terminals-0.5.3-py3-none-any.whl", hash = "sha256:41ee0d7dc0ebf2809c668e0fc726dfaf258fcd3e769568996ca731b6194ae9aa"}, + {file = "jupyter_server_terminals-0.5.3.tar.gz", hash = "sha256:5ae0295167220e9ace0edcfdb212afd2b01ee8d179fe6f23c899590e9b8a5269"}, ] [package.dependencies] @@ -1713,17 +1764,18 @@ test = ["jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-jupyter[server] (> [[package]] name = "jupyterlab" -version = "4.0.10" +version = "4.1.5" description = "JupyterLab computational environment" optional = false python-versions = ">=3.8" files = [ - {file = "jupyterlab-4.0.10-py3-none-any.whl", hash = "sha256:fe010ad9e37017488b468632ef2ead255fc7c671c5b64d9ca13e1f7b7e665c37"}, - {file = "jupyterlab-4.0.10.tar.gz", hash = "sha256:46177eb8ede70dc73be922ac99f8ef943bdc2dfbc6a31b353c4bde848a35dee1"}, + {file = "jupyterlab-4.1.5-py3-none-any.whl", hash = "sha256:3bc843382a25e1ab7bc31d9e39295a9f0463626692b7995597709c0ab236ab2c"}, + {file = "jupyterlab-4.1.5.tar.gz", hash = "sha256:c9ad75290cb10bfaff3624bf3fbb852319b4cce4c456613f8ebbaa98d03524db"}, ] [package.dependencies] async-lru = ">=1.0.0" +httpx = ">=0.25.0" importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""} importlib-resources = {version = ">=1.4", markers = "python_version < \"3.9\""} ipykernel = "*" @@ -1739,9 +1791,9 @@ tornado = ">=6.2.0" traitlets = "*" [package.extras] -dev = ["build", "bump2version", "coverage", "hatch", "pre-commit", 
"pytest-cov", "ruff (==0.1.6)"] -docs = ["jsx-lexer", "myst-parser", "pydata-sphinx-theme (>=0.13.0)", "pytest", "pytest-check-links", "pytest-tornasync", "sphinx (>=1.8,<7.2.0)", "sphinx-copybutton"] -docs-screenshots = ["altair (==5.0.1)", "ipython (==8.14.0)", "ipywidgets (==8.0.6)", "jupyterlab-geojson (==3.4.0)", "jupyterlab-language-pack-zh-cn (==4.0.post0)", "matplotlib (==3.7.1)", "nbconvert (>=7.0.0)", "pandas (==2.0.2)", "scipy (==1.10.1)", "vega-datasets (==0.9.0)"] +dev = ["build", "bump2version", "coverage", "hatch", "pre-commit", "pytest-cov", "ruff (==0.2.0)"] +docs = ["jsx-lexer", "myst-parser", "pydata-sphinx-theme (>=0.13.0)", "pytest", "pytest-check-links", "pytest-jupyter", "sphinx (>=1.8,<7.3.0)", "sphinx-copybutton"] +docs-screenshots = ["altair (==5.2.0)", "ipython (==8.16.1)", "ipywidgets (==8.1.1)", "jupyterlab-geojson (==3.4.0)", "jupyterlab-language-pack-zh-cn (==4.0.post6)", "matplotlib (==3.8.2)", "nbconvert (>=7.0.0)", "pandas (==2.2.0)", "scipy (==1.12.0)", "vega-datasets (==0.9.0)"] test = ["coverage", "pytest (>=7.0)", "pytest-check-links (>=0.7)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter (>=0.5.3)", "pytest-timeout", "pytest-tornasync", "requests", "requests-cache", "virtualenv"] [[package]] @@ -1757,13 +1809,13 @@ files = [ [[package]] name = "jupyterlab-server" -version = "2.25.2" +version = "2.25.4" description = "A set of server components for JupyterLab and JupyterLab like applications." 
optional = false python-versions = ">=3.8" files = [ - {file = "jupyterlab_server-2.25.2-py3-none-any.whl", hash = "sha256:5b1798c9cc6a44f65c757de9f97fc06fc3d42535afbf47d2ace5e964ab447aaf"}, - {file = "jupyterlab_server-2.25.2.tar.gz", hash = "sha256:bd0ec7a99ebcedc8bcff939ef86e52c378e44c2707e053fcd81d046ce979ee63"}, + {file = "jupyterlab_server-2.25.4-py3-none-any.whl", hash = "sha256:eb645ecc8f9b24bac5decc7803b6d5363250e16ec5af814e516bc2c54dd88081"}, + {file = "jupyterlab_server-2.25.4.tar.gz", hash = "sha256:2098198e1e82e0db982440f9b5136175d73bea2cd42a6480aa6fd502cb23c4f9"}, ] [package.dependencies] @@ -1779,17 +1831,17 @@ requests = ">=2.31" [package.extras] docs = ["autodoc-traits", "jinja2 (<3.2.0)", "mistune (<4)", "myst-parser", "pydata-sphinx-theme", "sphinx", "sphinx-copybutton", "sphinxcontrib-openapi (>0.8)"] openapi = ["openapi-core (>=0.18.0,<0.19.0)", "ruamel-yaml"] -test = ["hatch", "ipykernel", "openapi-core (>=0.18.0,<0.19.0)", "openapi-spec-validator (>=0.6.0,<0.8.0)", "pytest (>=7.0)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter[server] (>=0.6.2)", "pytest-timeout", "requests-mock", "ruamel-yaml", "sphinxcontrib-spelling", "strict-rfc3339", "werkzeug"] +test = ["hatch", "ipykernel", "openapi-core (>=0.18.0,<0.19.0)", "openapi-spec-validator (>=0.6.0,<0.8.0)", "pytest (>=7.0,<8)", "pytest-console-scripts", "pytest-cov", "pytest-jupyter[server] (>=0.6.2)", "pytest-timeout", "requests-mock", "ruamel-yaml", "sphinxcontrib-spelling", "strict-rfc3339", "werkzeug"] [[package]] name = "jupyterlab-widgets" -version = "3.0.9" +version = "3.0.10" description = "Jupyter interactive widgets for JupyterLab" optional = false python-versions = ">=3.7" files = [ - {file = "jupyterlab_widgets-3.0.9-py3-none-any.whl", hash = "sha256:3cf5bdf5b897bf3bccf1c11873aa4afd776d7430200f765e0686bd352487b58d"}, - {file = "jupyterlab_widgets-3.0.9.tar.gz", hash = "sha256:6005a4e974c7beee84060fdfba341a3218495046de8ae3ec64888e5fe19fdb4c"}, + {file = 
"jupyterlab_widgets-3.0.10-py3-none-any.whl", hash = "sha256:dd61f3ae7a5a7f80299e14585ce6cf3d6925a96c9103c978eda293197730cb64"}, + {file = "jupyterlab_widgets-3.0.10.tar.gz", hash = "sha256:04f2ac04976727e4f9d0fa91cdc2f1ab860f965e504c29dbd6a65c882c9d04c0"}, ] [[package]] @@ -1881,71 +1933,71 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] [[package]] name = "markupsafe" -version = "2.1.3" +version = "2.1.5" description = "Safely add untrusted strings to HTML/XML markup." optional = false python-versions = ">=3.7" files = [ - {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa"}, - {file = "MarkupSafe-2.1.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57"}, - {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f"}, - {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52"}, - {file = "MarkupSafe-2.1.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00"}, - {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6"}, - {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779"}, - {file = "MarkupSafe-2.1.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7"}, - {file = "MarkupSafe-2.1.3-cp310-cp310-win32.whl", hash = "sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431"}, - {file = 
"MarkupSafe-2.1.3-cp310-cp310-win_amd64.whl", hash = "sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559"}, - {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c"}, - {file = "MarkupSafe-2.1.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575"}, - {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee"}, - {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2"}, - {file = "MarkupSafe-2.1.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9"}, - {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc"}, - {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9"}, - {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, - {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, - {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, - {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, - {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, - {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, - {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e"}, - {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc"}, - {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48"}, - {file = "MarkupSafe-2.1.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155"}, - {file = "MarkupSafe-2.1.3-cp37-cp37m-win32.whl", hash = "sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0"}, - {file = "MarkupSafe-2.1.3-cp37-cp37m-win_amd64.whl", hash = "sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24"}, - {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4"}, - {file = "MarkupSafe-2.1.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0"}, - {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee"}, - {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be"}, - {file = "MarkupSafe-2.1.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e"}, - {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8"}, - {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3"}, - {file = "MarkupSafe-2.1.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = 
"sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d"}, - {file = "MarkupSafe-2.1.3-cp38-cp38-win32.whl", hash = "sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5"}, - {file = "MarkupSafe-2.1.3-cp38-cp38-win_amd64.whl", hash = "sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc"}, - {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198"}, - {file = "MarkupSafe-2.1.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b"}, - {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58"}, - {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e"}, - {file = "MarkupSafe-2.1.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c"}, - {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636"}, - {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea"}, - {file = "MarkupSafe-2.1.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e"}, - {file = "MarkupSafe-2.1.3-cp39-cp39-win32.whl", hash = "sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2"}, - {file = "MarkupSafe-2.1.3-cp39-cp39-win_amd64.whl", hash = "sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba"}, - {file = "MarkupSafe-2.1.3.tar.gz", hash = 
"sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-win32.whl", hash = "sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-win_amd64.whl", hash = "sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2"}, + {file = 
"MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-win32.whl", hash = "sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-win_amd64.whl", hash = "sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5"}, + {file = 
"MarkupSafe-2.1.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-win32.whl", hash = "sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = 
"sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-win32.whl", hash = "sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-win_amd64.whl", hash = "sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-win32.whl", hash = "sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-win_amd64.whl", hash = "sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_universal2.whl", hash = 
"sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-win32.whl", hash = "sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-win_amd64.whl", hash = "sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5"}, + {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, ] [[package]] @@ -2022,149 +2074,161 @@ tests = ["pytest (>=4.6)"] [[package]] name = "multidict" -version = "6.0.4" +version = "6.0.5" description = "multidict implementation" optional = false python-versions = ">=3.7" files = [ - {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b1a97283e0c85772d613878028fec909f003993e1007eafa715b24b377cb9b8"}, - {file = 
"multidict-6.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eeb6dcc05e911516ae3d1f207d4b0520d07f54484c49dfc294d6e7d63b734171"}, - {file = "multidict-6.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d6d635d5209b82a3492508cf5b365f3446afb65ae7ebd755e70e18f287b0adf7"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c048099e4c9e9d615545e2001d3d8a4380bd403e1a0578734e0d31703d1b0c0b"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea20853c6dbbb53ed34cb4d080382169b6f4554d394015f1bef35e881bf83547"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16d232d4e5396c2efbbf4f6d4df89bfa905eb0d4dc5b3549d872ab898451f569"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c63aaa167f6c6b04ef2c85704e93af16c11d20de1d133e39de6a0e84582a93"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:64bdf1086b6043bf519869678f5f2757f473dee970d7abf6da91ec00acb9cb98"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:43644e38f42e3af682690876cff722d301ac585c5b9e1eacc013b7a3f7b696a0"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7582a1d1030e15422262de9f58711774e02fa80df0d1578995c76214f6954988"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ddff9c4e225a63a5afab9dd15590432c22e8057e1a9a13d28ed128ecf047bbdc"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ee2a1ece51b9b9e7752e742cfb661d2a29e7bcdba2d27e66e28a99f1890e4fa0"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a2e4369eb3d47d2034032a26c7a80fcb21a2cb22e1173d761a162f11e562caa5"}, - {file = "multidict-6.0.4-cp310-cp310-win32.whl", hash = 
"sha256:574b7eae1ab267e5f8285f0fe881f17efe4b98c39a40858247720935b893bba8"}, - {file = "multidict-6.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:4dcbb0906e38440fa3e325df2359ac6cb043df8e58c965bb45f4e406ecb162cc"}, - {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0dfad7a5a1e39c53ed00d2dd0c2e36aed4650936dc18fd9a1826a5ae1cad6f03"}, - {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:64da238a09d6039e3bd39bb3aee9c21a5e34f28bfa5aa22518581f910ff94af3"}, - {file = "multidict-6.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ff959bee35038c4624250473988b24f846cbeb2c6639de3602c073f10410ceba"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01a3a55bd90018c9c080fbb0b9f4891db37d148a0a18722b42f94694f8b6d4c9"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5cb09abb18c1ea940fb99360ea0396f34d46566f157122c92dfa069d3e0e982"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:666daae833559deb2d609afa4490b85830ab0dfca811a98b70a205621a6109fe"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11bdf3f5e1518b24530b8241529d2050014c884cf18b6fc69c0c2b30ca248710"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d18748f2d30f94f498e852c67d61261c643b349b9d2a581131725595c45ec6c"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:458f37be2d9e4c95e2d8866a851663cbc76e865b78395090786f6cd9b3bbf4f4"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b1a2eeedcead3a41694130495593a559a668f382eee0727352b9a41e1c45759a"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7d6ae9d593ef8641544d6263c7fa6408cc90370c8cb2bbb65f8d43e5b0351d9c"}, - 
{file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5979b5632c3e3534e42ca6ff856bb24b2e3071b37861c2c727ce220d80eee9ed"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dcfe792765fab89c365123c81046ad4103fcabbc4f56d1c1997e6715e8015461"}, - {file = "multidict-6.0.4-cp311-cp311-win32.whl", hash = "sha256:3601a3cece3819534b11d4efc1eb76047488fddd0c85a3948099d5da4d504636"}, - {file = "multidict-6.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:81a4f0b34bd92df3da93315c6a59034df95866014ac08535fc819f043bfd51f0"}, - {file = "multidict-6.0.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:67040058f37a2a51ed8ea8f6b0e6ee5bd78ca67f169ce6122f3e2ec80dfe9b78"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:853888594621e6604c978ce2a0444a1e6e70c8d253ab65ba11657659dcc9100f"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:39ff62e7d0f26c248b15e364517a72932a611a9b75f35b45be078d81bdb86603"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af048912e045a2dc732847d33821a9d84ba553f5c5f028adbd364dd4765092ac"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e8b901e607795ec06c9e42530788c45ac21ef3aaa11dbd0c69de543bfb79a9"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62501642008a8b9871ddfccbf83e4222cf8ac0d5aeedf73da36153ef2ec222d2"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:99b76c052e9f1bc0721f7541e5e8c05db3941eb9ebe7b8553c625ef88d6eefde"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:509eac6cf09c794aa27bcacfd4d62c885cce62bef7b2c3e8b2e49d365b5003fe"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = 
"sha256:21a12c4eb6ddc9952c415f24eef97e3e55ba3af61f67c7bc388dcdec1404a067"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:5cad9430ab3e2e4fa4a2ef4450f548768400a2ac635841bc2a56a2052cdbeb87"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab55edc2e84460694295f401215f4a58597f8f7c9466faec545093045476327d"}, - {file = "multidict-6.0.4-cp37-cp37m-win32.whl", hash = "sha256:5a4dcf02b908c3b8b17a45fb0f15b695bf117a67b76b7ad18b73cf8e92608775"}, - {file = "multidict-6.0.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6ed5f161328b7df384d71b07317f4d8656434e34591f20552c7bcef27b0ab88e"}, - {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5fc1b16f586f049820c5c5b17bb4ee7583092fa0d1c4e28b5239181ff9532e0c"}, - {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1502e24330eb681bdaa3eb70d6358e818e8e8f908a22a1851dfd4e15bc2f8161"}, - {file = "multidict-6.0.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b692f419760c0e65d060959df05f2a531945af31fda0c8a3b3195d4efd06de11"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45e1ecb0379bfaab5eef059f50115b54571acfbe422a14f668fc8c27ba410e7e"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddd3915998d93fbcd2566ddf9cf62cdb35c9e093075f862935573d265cf8f65d"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59d43b61c59d82f2effb39a93c48b845efe23a3852d201ed2d24ba830d0b4cf2"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc8e1d0c705233c5dd0c5e6460fbad7827d5d36f310a0fadfd45cc3029762258"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6aa0418fcc838522256761b3415822626f866758ee0bc6632c9486b179d0b52"}, - {file = 
"multidict-6.0.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6748717bb10339c4760c1e63da040f5f29f5ed6e59d76daee30305894069a660"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4d1a3d7ef5e96b1c9e92f973e43aa5e5b96c659c9bc3124acbbd81b0b9c8a951"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4372381634485bec7e46718edc71528024fcdc6f835baefe517b34a33c731d60"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:fc35cb4676846ef752816d5be2193a1e8367b4c1397b74a565a9d0389c433a1d"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4b9d9e4e2b37daddb5c23ea33a3417901fa7c7b3dee2d855f63ee67a0b21e5b1"}, - {file = "multidict-6.0.4-cp38-cp38-win32.whl", hash = "sha256:e41b7e2b59679edfa309e8db64fdf22399eec4b0b24694e1b2104fb789207779"}, - {file = "multidict-6.0.4-cp38-cp38-win_amd64.whl", hash = "sha256:d6c254ba6e45d8e72739281ebc46ea5eb5f101234f3ce171f0e9f5cc86991480"}, - {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:16ab77bbeb596e14212e7bab8429f24c1579234a3a462105cda4a66904998664"}, - {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc779e9e6f7fda81b3f9aa58e3a6091d49ad528b11ed19f6621408806204ad35"}, - {file = "multidict-6.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ceef517eca3e03c1cceb22030a3e39cb399ac86bff4e426d4fc6ae49052cc60"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:281af09f488903fde97923c7744bb001a9b23b039a909460d0f14edc7bf59706"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52f2dffc8acaba9a2f27174c41c9e57f60b907bb9f096b36b1a1f3be71c6284d"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b41156839806aecb3641f3208c0dafd3ac7775b9c4c422d82ee2a45c34ba81ca"}, - {file = 
"multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3fc56f88cc98ef8139255cf8cd63eb2c586531e43310ff859d6bb3a6b51f1"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8316a77808c501004802f9beebde51c9f857054a0c871bd6da8280e718444449"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f70b98cd94886b49d91170ef23ec5c0e8ebb6f242d734ed7ed677b24d50c82cf"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bf6774e60d67a9efe02b3616fee22441d86fab4c6d335f9d2051d19d90a40063"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:e69924bfcdda39b722ef4d9aa762b2dd38e4632b3641b1d9a57ca9cd18f2f83a"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:6b181d8c23da913d4ff585afd1155a0e1194c0b50c54fcfe286f70cdaf2b7176"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52509b5be062d9eafc8170e53026fbc54cf3b32759a23d07fd935fb04fc22d95"}, - {file = "multidict-6.0.4-cp39-cp39-win32.whl", hash = "sha256:27c523fbfbdfd19c6867af7346332b62b586eed663887392cff78d614f9ec313"}, - {file = "multidict-6.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:33029f5734336aa0d4c0384525da0387ef89148dc7191aae00ca5fb23d7aafc2"}, - {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"}, + {file = 
"multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"}, + {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"}, + {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = 
"sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"}, + {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"}, + {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"}, + {file = 
"multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"}, + {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"}, + {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"}, + {file = "multidict-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:19fe01cea168585ba0f678cad6f58133db2aa14eccaf22f88e4a6dccadfad8b3"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf7a982604375a8d49b6cc1b781c1747f243d91b81035a9b43a2126c04766f5"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:107c0cdefe028703fb5dafe640a409cb146d44a6ae201e55b35a4af8e95457dd"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:403c0911cd5d5791605808b942c88a8155c2592e05332d2bf78f18697a5fa15e"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aeaf541ddbad8311a87dd695ed9642401131ea39ad7bc8cf3ef3967fd093b626"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4972624066095e52b569e02b5ca97dbd7a7ddd4294bf4e7247d52635630dd83"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d946b0a9eb8aaa590df1fe082cee553ceab173e6cb5b03239716338629c50c7a"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = 
"sha256:b55358304d7a73d7bdf5de62494aaf70bd33015831ffd98bc498b433dfe5b10c"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:a3145cb08d8625b2d3fee1b2d596a8766352979c9bffe5d7833e0503d0f0b5e5"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d65f25da8e248202bd47445cec78e0025c0fe7582b23ec69c3b27a640dd7a8e3"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c9bf56195c6bbd293340ea82eafd0071cb3d450c703d2c93afb89f93b8386ccc"}, + {file = "multidict-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:69db76c09796b313331bb7048229e3bee7928eb62bab5e071e9f7fcc4879caee"}, + {file = "multidict-6.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:fce28b3c8a81b6b36dfac9feb1de115bab619b3c13905b419ec71d03a3fc1423"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:76f067f5121dcecf0d63a67f29080b26c43c71a98b10c701b0677e4a065fbd54"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b82cc8ace10ab5bd93235dfaab2021c70637005e1ac787031f4d1da63d493c1d"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cb241881eefd96b46f89b1a056187ea8e9ba14ab88ba632e68d7a2ecb7aadf7"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e94e6912639a02ce173341ff62cc1201232ab86b8a8fcc05572741a5dc7d93"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09a892e4a9fb47331da06948690ae38eaa2426de97b4ccbfafbdcbe5c8f37ff8"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55205d03e8a598cfc688c71ca8ea5f66447164efff8869517f175ea632c7cb7b"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37b15024f864916b4951adb95d3a80c9431299080341ab9544ed148091b53f50"}, + {file = 
"multidict-6.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2a1dee728b52b33eebff5072817176c172050d44d67befd681609b4746e1c2e"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edd08e6f2f1a390bf137080507e44ccc086353c8e98c657e666c017718561b89"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60d698e8179a42ec85172d12f50b1668254628425a6bd611aba022257cac1386"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3d25f19500588cbc47dc19081d78131c32637c25804df8414463ec908631e453"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4cc0ef8b962ac7a5e62b9e826bd0cd5040e7d401bc45a6835910ed699037a461"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:eca2e9d0cc5a889850e9bbd68e98314ada174ff6ccd1129500103df7a94a7a44"}, + {file = "multidict-6.0.5-cp38-cp38-win32.whl", hash = "sha256:4a6a4f196f08c58c59e0b8ef8ec441d12aee4125a7d4f4fef000ccb22f8d7241"}, + {file = "multidict-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:0275e35209c27a3f7951e1ce7aaf93ce0d163b28948444bec61dd7badc6d3f8c"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"}, + {file = 
"multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"}, + {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"}, + {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"}, + {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"}, + {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, ] [[package]] name = "multiprocess" -version = "0.70.15" +version = "0.70.16" description = "better multiprocessing and multithreading in Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files 
= [ - {file = "multiprocess-0.70.15-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:aa36c7ed16f508091438687fe9baa393a7a8e206731d321e443745e743a0d4e5"}, - {file = "multiprocess-0.70.15-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:20e024018c46d0d1602024c613007ac948f9754659e3853b0aa705e83f6931d8"}, - {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_i686.whl", hash = "sha256:e576062981c91f0fe8a463c3d52506e598dfc51320a8dd8d78b987dfca91c5db"}, - {file = "multiprocess-0.70.15-pp37-pypy37_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:e73f497e6696a0f5433ada2b3d599ae733b87a6e8b008e387c62ac9127add177"}, - {file = "multiprocess-0.70.15-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:73db2e7b32dcc7f9b0f075c2ffa45c90b6729d3f1805f27e88534c8d321a1be5"}, - {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_i686.whl", hash = "sha256:4271647bd8a49c28ecd6eb56a7fdbd3c212c45529ad5303b40b3c65fc6928e5f"}, - {file = "multiprocess-0.70.15-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:cf981fb998d6ec3208cb14f0cf2e9e80216e834f5d51fd09ebc937c32b960902"}, - {file = "multiprocess-0.70.15-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:18f9f2c7063346d1617bd1684fdcae8d33380ae96b99427260f562e1a1228b67"}, - {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_i686.whl", hash = "sha256:0eac53214d664c49a34695e5824872db4006b1a465edd7459a251809c3773370"}, - {file = "multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:1a51dd34096db47fb21fa2b839e615b051d51b97af9a67afbcdaa67186b44883"}, - {file = "multiprocess-0.70.15-py310-none-any.whl", hash = "sha256:7dd58e33235e83cf09d625e55cffd7b0f0eede7ee9223cdd666a87624f60c21a"}, - {file = "multiprocess-0.70.15-py311-none-any.whl", hash = "sha256:134f89053d82c9ed3b73edd3a2531eb791e602d4f4156fc92a79259590bd9670"}, - {file = "multiprocess-0.70.15-py37-none-any.whl", hash = "sha256:f7d4a1629bccb433114c3b4885f69eccc200994323c80f6feee73b0edc9199c5"}, - 
{file = "multiprocess-0.70.15-py38-none-any.whl", hash = "sha256:bee9afba476c91f9ebee7beeee0601face9eff67d822e893f9a893725fbd6316"}, - {file = "multiprocess-0.70.15-py39-none-any.whl", hash = "sha256:3e0953f5d52b4c76f1c973eaf8214554d146f2be5decb48e928e55c7a2d19338"}, - {file = "multiprocess-0.70.15.tar.gz", hash = "sha256:f20eed3036c0ef477b07a4177cf7c1ba520d9a2677870a4f47fe026f0cd6787e"}, + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"}, + {file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"}, + {file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"}, + {file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"}, + {file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"}, + {file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"}, + {file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"}, + {file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"}, + {file = "multiprocess-0.70.16-py39-none-any.whl", hash = 
"sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"}, + {file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"}, ] [package.dependencies] -dill = ">=0.3.7" +dill = ">=0.3.8" [[package]] name = "mypy" -version = "1.8.0" +version = "1.9.0" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" files = [ - {file = "mypy-1.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485a8942f671120f76afffff70f259e1cd0f0cfe08f81c05d8816d958d4577d3"}, - {file = "mypy-1.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:df9824ac11deaf007443e7ed2a4a26bebff98d2bc43c6da21b2b64185da011c4"}, - {file = "mypy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2afecd6354bbfb6e0160f4e4ad9ba6e4e003b767dd80d85516e71f2e955ab50d"}, - {file = "mypy-1.8.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8963b83d53ee733a6e4196954502b33567ad07dfd74851f32be18eb932fb1cb9"}, - {file = "mypy-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:e46f44b54ebddbeedbd3d5b289a893219065ef805d95094d16a0af6630f5d410"}, - {file = "mypy-1.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:855fe27b80375e5c5878492f0729540db47b186509c98dae341254c8f45f42ae"}, - {file = "mypy-1.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4c886c6cce2d070bd7df4ec4a05a13ee20c0aa60cb587e8d1265b6c03cf91da3"}, - {file = "mypy-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d19c413b3c07cbecf1f991e2221746b0d2a9410b59cb3f4fb9557f0365a1a817"}, - {file = "mypy-1.8.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9261ed810972061388918c83c3f5cd46079d875026ba97380f3e3978a72f503d"}, - {file = "mypy-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:51720c776d148bad2372ca21ca29256ed483aa9a4cdefefcef49006dff2a6835"}, - {file = "mypy-1.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:52825b01f5c4c1c4eb0db253ec09c7aa17e1a7304d247c48b6f3599ef40db8bd"}, - {file = "mypy-1.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f5ac9a4eeb1ec0f1ccdc6f326bcdb464de5f80eb07fb38b5ddd7b0de6bc61e55"}, - {file = "mypy-1.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afe3fe972c645b4632c563d3f3eff1cdca2fa058f730df2b93a35e3b0c538218"}, - {file = "mypy-1.8.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:42c6680d256ab35637ef88891c6bd02514ccb7e1122133ac96055ff458f93fc3"}, - {file = "mypy-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:720a5ca70e136b675af3af63db533c1c8c9181314d207568bbe79051f122669e"}, - {file = "mypy-1.8.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:028cf9f2cae89e202d7b6593cd98db6759379f17a319b5faf4f9978d7084cdc6"}, - {file = "mypy-1.8.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4e6d97288757e1ddba10dd9549ac27982e3e74a49d8d0179fc14d4365c7add66"}, - {file = "mypy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f1478736fcebb90f97e40aff11a5f253af890c845ee0c850fe80aa060a267c6"}, - {file = "mypy-1.8.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42419861b43e6962a649068a61f4a4839205a3ef525b858377a960b9e2de6e0d"}, - {file = "mypy-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:2b5b6c721bd4aabaadead3a5e6fa85c11c6c795e0c81a7215776ef8afc66de02"}, - {file = "mypy-1.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5c1538c38584029352878a0466f03a8ee7547d7bd9f641f57a0f3017a7c905b8"}, - {file = "mypy-1.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ef4be7baf08a203170f29e89d79064463b7fc7a0908b9d0d5114e8009c3a259"}, - {file = "mypy-1.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7178def594014aa6c35a8ff411cf37d682f428b3b5617ca79029d8ae72f5402b"}, - {file = "mypy-1.8.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ab3c84fa13c04aeeeabb2a7f67a25ef5d77ac9d6486ff33ded762ef353aa5592"}, - {file = 
"mypy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:99b00bc72855812a60d253420d8a2eae839b0afa4938f09f4d2aa9bb4654263a"}, - {file = "mypy-1.8.0-py3-none-any.whl", hash = "sha256:538fd81bb5e430cc1381a443971c0475582ff9f434c16cd46d2c66763ce85d9d"}, - {file = "mypy-1.8.0.tar.gz", hash = "sha256:6ff8b244d7085a0b425b56d327b480c3b29cafbd2eff27316a004f9a7391ae07"}, + {file = "mypy-1.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f8a67616990062232ee4c3952f41c779afac41405806042a8126fe96e098419f"}, + {file = "mypy-1.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d357423fa57a489e8c47b7c85dfb96698caba13d66e086b412298a1a0ea3b0ed"}, + {file = "mypy-1.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49c87c15aed320de9b438ae7b00c1ac91cd393c1b854c2ce538e2a72d55df150"}, + {file = "mypy-1.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:48533cdd345c3c2e5ef48ba3b0d3880b257b423e7995dada04248725c6f77374"}, + {file = "mypy-1.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:4d3dbd346cfec7cb98e6cbb6e0f3c23618af826316188d587d1c1bc34f0ede03"}, + {file = "mypy-1.9.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:653265f9a2784db65bfca694d1edd23093ce49740b2244cde583aeb134c008f3"}, + {file = "mypy-1.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3a3c007ff3ee90f69cf0a15cbcdf0995749569b86b6d2f327af01fd1b8aee9dc"}, + {file = "mypy-1.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2418488264eb41f69cc64a69a745fad4a8f86649af4b1041a4c64ee61fc61129"}, + {file = "mypy-1.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:68edad3dc7d70f2f17ae4c6c1b9471a56138ca22722487eebacfd1eb5321d612"}, + {file = "mypy-1.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:85ca5fcc24f0b4aeedc1d02f93707bccc04733f21d41c88334c5482219b1ccb3"}, + {file = "mypy-1.9.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aceb1db093b04db5cd390821464504111b8ec3e351eb85afd1433490163d60cd"}, + {file = 
"mypy-1.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0235391f1c6f6ce487b23b9dbd1327b4ec33bb93934aa986efe8a9563d9349e6"}, + {file = "mypy-1.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4d5ddc13421ba3e2e082a6c2d74c2ddb3979c39b582dacd53dd5d9431237185"}, + {file = "mypy-1.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:190da1ee69b427d7efa8aa0d5e5ccd67a4fb04038c380237a0d96829cb157913"}, + {file = "mypy-1.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:fe28657de3bfec596bbeef01cb219833ad9d38dd5393fc649f4b366840baefe6"}, + {file = "mypy-1.9.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e54396d70be04b34f31d2edf3362c1edd023246c82f1730bbf8768c28db5361b"}, + {file = "mypy-1.9.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5e6061f44f2313b94f920e91b204ec600982961e07a17e0f6cd83371cb23f5c2"}, + {file = "mypy-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81a10926e5473c5fc3da8abb04119a1f5811a236dc3a38d92015cb1e6ba4cb9e"}, + {file = "mypy-1.9.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b685154e22e4e9199fc95f298661deea28aaede5ae16ccc8cbb1045e716b3e04"}, + {file = "mypy-1.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:5d741d3fc7c4da608764073089e5f58ef6352bedc223ff58f2f038c2c4698a89"}, + {file = "mypy-1.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:587ce887f75dd9700252a3abbc9c97bbe165a4a630597845c61279cf32dfbf02"}, + {file = "mypy-1.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f88566144752999351725ac623471661c9d1cd8caa0134ff98cceeea181789f4"}, + {file = "mypy-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61758fabd58ce4b0720ae1e2fea5cfd4431591d6d590b197775329264f86311d"}, + {file = "mypy-1.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e49499be624dead83927e70c756970a0bc8240e9f769389cdf5714b0784ca6bf"}, + {file = "mypy-1.9.0-cp39-cp39-win_amd64.whl", hash = 
"sha256:571741dc4194b4f82d344b15e8837e8c5fcc462d66d076748142327626a1b6e9"}, + {file = "mypy-1.9.0-py3-none-any.whl", hash = "sha256:a260627a570559181a9ea5de61ac6297aa5af202f06fd7ab093ce74e7181e43e"}, + {file = "mypy-1.9.0.tar.gz", hash = "sha256:3cc5da0127e6a478cddd906068496a97a7618a21ce9b54bde5bf7e539c7af974"}, ] [package.dependencies] @@ -2217,13 +2281,13 @@ testing-docutils = ["pygments", "pytest (>=7,<8)", "pytest-param-files (>=0.3.4, [[package]] name = "nbclient" -version = "0.9.0" +version = "0.10.0" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." optional = false python-versions = ">=3.8.0" files = [ - {file = "nbclient-0.9.0-py3-none-any.whl", hash = "sha256:a3a1ddfb34d4a9d17fc744d655962714a866639acd30130e9be84191cd97cd15"}, - {file = "nbclient-0.9.0.tar.gz", hash = "sha256:4b28c207877cf33ef3a9838cdc7a54c5ceff981194a82eac59d558f05487295e"}, + {file = "nbclient-0.10.0-py3-none-any.whl", hash = "sha256:f13e3529332a1f1f81d82a53210322476a168bb7090a0289c795fe9cc11c9d3f"}, + {file = "nbclient-0.10.0.tar.gz", hash = "sha256:4b3f1b7dba531e498449c4db4f53da339c91d449dc11e9af3a43b4eb5c5abb09"}, ] [package.dependencies] @@ -2235,17 +2299,17 @@ traitlets = ">=5.4" [package.extras] dev = ["pre-commit"] docs = ["autodoc-traits", "mock", "moto", "myst-parser", "nbclient[test]", "sphinx (>=1.7)", "sphinx-book-theme", "sphinxcontrib-spelling"] -test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov (>=4.0)", "testpath", "xmltodict"] +test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "pytest (>=7.0,<8)", "pytest-asyncio", "pytest-cov (>=4.0)", "testpath", "xmltodict"] [[package]] name = "nbconvert" -version = "7.14.0" -description = "Converting Jupyter Notebooks" +version = "7.16.3" +description = "Converting Jupyter Notebooks (.ipynb files) to other formats. 
Output formats include asciidoc, html, latex, markdown, pdf, py, rst, script. nbconvert can be used both as a Python library (`import nbconvert`) or as a command line tool (invoked as `jupyter nbconvert ...`)." optional = false python-versions = ">=3.8" files = [ - {file = "nbconvert-7.14.0-py3-none-any.whl", hash = "sha256:483dde47facdaa4875903d651305ad53cd76e2255ae3c61efe412a95f2d22a24"}, - {file = "nbconvert-7.14.0.tar.gz", hash = "sha256:92b9a44b63e5a7fb4f6fa0ef41261e35c16925046ccd1c04a5c8099bf100476e"}, + {file = "nbconvert-7.16.3-py3-none-any.whl", hash = "sha256:ddeff14beeeedf3dd0bc506623e41e4507e551736de59df69a91f86700292b3b"}, + {file = "nbconvert-7.16.3.tar.gz", hash = "sha256:a6733b78ce3d47c3f85e504998495b07e6ea9cf9bf6ec1c98dda63ec6ad19142"}, ] [package.dependencies] @@ -2272,18 +2336,18 @@ docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sp qtpdf = ["nbconvert[qtpng]"] qtpng = ["pyqtwebengine (>=5.15)"] serve = ["tornado (>=6.1)"] -test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest"] +test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest (>=7)"] webpdf = ["playwright"] [[package]] name = "nbformat" -version = "5.9.2" +version = "5.10.3" description = "The Jupyter Notebook format" optional = false python-versions = ">=3.8" files = [ - {file = "nbformat-5.9.2-py3-none-any.whl", hash = "sha256:1c5172d786a41b82bcfd0c23f9e6b6f072e8fb49c39250219e4acfff1efe89e9"}, - {file = "nbformat-5.9.2.tar.gz", hash = "sha256:5f98b5ba1997dff175e77e0c17d5c10a96eaed2cbd1de3533d1fc35d5e111192"}, + {file = "nbformat-5.10.3-py3-none-any.whl", hash = "sha256:d9476ca28676799af85385f409b49d95e199951477a159a576ef2a675151e5e8"}, + {file = "nbformat-5.10.3.tar.gz", hash = "sha256:60ed5e910ef7c6264b87d644f276b1b49e24011930deef54605188ddeb211685"}, ] [package.dependencies] @@ -2335,13 +2399,13 @@ pytest = ">=2.8" [[package]] name = "nest-asyncio" -version = "1.5.8" +version = "1.6.0" description = "Patch asyncio to allow nested event loops" 
optional = false python-versions = ">=3.5" files = [ - {file = "nest_asyncio-1.5.8-py3-none-any.whl", hash = "sha256:accda7a339a70599cb08f9dd09a67e0c2ef8d8d6f4c07f96ab203f2ae254e48d"}, - {file = "nest_asyncio-1.5.8.tar.gz", hash = "sha256:25aa2ca0d2a5b5531956b9e273b45cf664cae2b145101d73b86b199978d48fdb"}, + {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, + {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, ] [[package]] @@ -2364,18 +2428,18 @@ test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] [[package]] name = "notebook" -version = "7.0.6" +version = "7.1.2" description = "Jupyter Notebook - A web-based notebook environment for interactive computing" optional = false python-versions = ">=3.8" files = [ - {file = "notebook-7.0.6-py3-none-any.whl", hash = "sha256:0fe8f67102fea3744fedf652e4c15339390902ca70c5a31c4f547fa23da697cc"}, - {file = "notebook-7.0.6.tar.gz", hash = "sha256:ec6113b06529019f7f287819af06c97a2baf7a95ac21a8f6e32192898e9f9a58"}, + {file = "notebook-7.1.2-py3-none-any.whl", hash = "sha256:fc6c24b9aef18d0cd57157c9c47e95833b9b0bdc599652639acf0bdb61dc7d5f"}, + {file = "notebook-7.1.2.tar.gz", hash = "sha256:efc2c80043909e0faa17fce9e9b37c059c03af0ec99a4d4db84cb21d9d2e936a"}, ] [package.dependencies] jupyter-server = ">=2.4.0,<3" -jupyterlab = ">=4.0.2,<5" +jupyterlab = ">=4.1.1,<4.2" jupyterlab-server = ">=2.22.1,<3" notebook-shim = ">=0.2,<0.3" tornado = ">=6.2.0" @@ -2387,13 +2451,13 @@ test = ["importlib-resources (>=5.0)", "ipykernel", "jupyter-server[test] (>=2.4 [[package]] name = "notebook-shim" -version = "0.2.3" +version = "0.2.4" description = "A shim layer for notebook traits and config" optional = false python-versions = ">=3.7" files = [ - {file = "notebook_shim-0.2.3-py3-none-any.whl", hash = "sha256:a83496a43341c1674b093bfcebf0fe8e74cbe7eda5fd2bbc56f8e39e1486c0c7"}, - {file 
= "notebook_shim-0.2.3.tar.gz", hash = "sha256:f69388ac283ae008cd506dda10d0288b09a017d822d5e8c7129a152cbd3ce7e9"}, + {file = "notebook_shim-0.2.4-py3-none-any.whl", hash = "sha256:411a5be4e9dc882a074ccbcae671eda64cceb068767e9a3419096986560e1cef"}, + {file = "notebook_shim-0.2.4.tar.gz", hash = "sha256:b4b2cfa1b65d98307ca24361f5b30fe785b53c3fd07b7a47e89acb5e6ac638cb"}, ] [package.dependencies] @@ -2441,47 +2505,47 @@ files = [ [[package]] name = "numpy" -version = "1.26.3" +version = "1.26.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.9" files = [ - {file = "numpy-1.26.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:806dd64230dbbfaca8a27faa64e2f414bf1c6622ab78cc4264f7f5f028fee3bf"}, - {file = "numpy-1.26.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02f98011ba4ab17f46f80f7f8f1c291ee7d855fcef0a5a98db80767a468c85cd"}, - {file = "numpy-1.26.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d45b3ec2faed4baca41c76617fcdcfa4f684ff7a151ce6fc78ad3b6e85af0a6"}, - {file = "numpy-1.26.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdd2b45bf079d9ad90377048e2747a0c82351989a2165821f0c96831b4a2a54b"}, - {file = "numpy-1.26.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:211ddd1e94817ed2d175b60b6374120244a4dd2287f4ece45d49228b4d529178"}, - {file = "numpy-1.26.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b1240f767f69d7c4c8a29adde2310b871153df9b26b5cb2b54a561ac85146485"}, - {file = "numpy-1.26.3-cp310-cp310-win32.whl", hash = "sha256:21a9484e75ad018974a2fdaa216524d64ed4212e418e0a551a2d83403b0531d3"}, - {file = "numpy-1.26.3-cp310-cp310-win_amd64.whl", hash = "sha256:9e1591f6ae98bcfac2a4bbf9221c0b92ab49762228f38287f6eeb5f3f55905ce"}, - {file = "numpy-1.26.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b831295e5472954104ecb46cd98c08b98b49c69fdb7040483aff799a755a7374"}, - {file = 
"numpy-1.26.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9e87562b91f68dd8b1c39149d0323b42e0082db7ddb8e934ab4c292094d575d6"}, - {file = "numpy-1.26.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c66d6fec467e8c0f975818c1796d25c53521124b7cfb760114be0abad53a0a2"}, - {file = "numpy-1.26.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f25e2811a9c932e43943a2615e65fc487a0b6b49218899e62e426e7f0a57eeda"}, - {file = "numpy-1.26.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:af36e0aa45e25c9f57bf684b1175e59ea05d9a7d3e8e87b7ae1a1da246f2767e"}, - {file = "numpy-1.26.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:51c7f1b344f302067b02e0f5b5d2daa9ed4a721cf49f070280ac202738ea7f00"}, - {file = "numpy-1.26.3-cp311-cp311-win32.whl", hash = "sha256:7ca4f24341df071877849eb2034948459ce3a07915c2734f1abb4018d9c49d7b"}, - {file = "numpy-1.26.3-cp311-cp311-win_amd64.whl", hash = "sha256:39763aee6dfdd4878032361b30b2b12593fb445ddb66bbac802e2113eb8a6ac4"}, - {file = "numpy-1.26.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a7081fd19a6d573e1a05e600c82a1c421011db7935ed0d5c483e9dd96b99cf13"}, - {file = "numpy-1.26.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12c70ac274b32bc00c7f61b515126c9205323703abb99cd41836e8125ea0043e"}, - {file = "numpy-1.26.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f784e13e598e9594750b2ef6729bcd5a47f6cfe4a12cca13def35e06d8163e3"}, - {file = "numpy-1.26.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f24750ef94d56ce6e33e4019a8a4d68cfdb1ef661a52cdaee628a56d2437419"}, - {file = "numpy-1.26.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:77810ef29e0fb1d289d225cabb9ee6cf4d11978a00bb99f7f8ec2132a84e0166"}, - {file = "numpy-1.26.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8ed07a90f5450d99dad60d3799f9c03c6566709bd53b497eb9ccad9a55867f36"}, - {file = "numpy-1.26.3-cp312-cp312-win32.whl", 
hash = "sha256:f73497e8c38295aaa4741bdfa4fda1a5aedda5473074369eca10626835445511"}, - {file = "numpy-1.26.3-cp312-cp312-win_amd64.whl", hash = "sha256:da4b0c6c699a0ad73c810736303f7fbae483bcb012e38d7eb06a5e3b432c981b"}, - {file = "numpy-1.26.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1666f634cb3c80ccbd77ec97bc17337718f56d6658acf5d3b906ca03e90ce87f"}, - {file = "numpy-1.26.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18c3319a7d39b2c6a9e3bb75aab2304ab79a811ac0168a671a62e6346c29b03f"}, - {file = "numpy-1.26.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b7e807d6888da0db6e7e75838444d62495e2b588b99e90dd80c3459594e857b"}, - {file = "numpy-1.26.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4d362e17bcb0011738c2d83e0a65ea8ce627057b2fdda37678f4374a382a137"}, - {file = "numpy-1.26.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b8c275f0ae90069496068c714387b4a0eba5d531aace269559ff2b43655edd58"}, - {file = "numpy-1.26.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cc0743f0302b94f397a4a65a660d4cd24267439eb16493fb3caad2e4389bccbb"}, - {file = "numpy-1.26.3-cp39-cp39-win32.whl", hash = "sha256:9bc6d1a7f8cedd519c4b7b1156d98e051b726bf160715b769106661d567b3f03"}, - {file = "numpy-1.26.3-cp39-cp39-win_amd64.whl", hash = "sha256:867e3644e208c8922a3be26fc6bbf112a035f50f0a86497f98f228c50c607bb2"}, - {file = "numpy-1.26.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3c67423b3703f8fbd90f5adaa37f85b5794d3366948efe9a5190a5f3a83fc34e"}, - {file = "numpy-1.26.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46f47ee566d98849323f01b349d58f2557f02167ee301e5e28809a8c0e27a2d0"}, - {file = "numpy-1.26.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a8474703bffc65ca15853d5fd4d06b18138ae90c17c8d12169968e998e448bb5"}, - {file = "numpy-1.26.3.tar.gz", hash = "sha256:697df43e2b6310ecc9d95f05d5ef20eacc09c7c4ecc9da3f235d39e71b7da1e4"}, + {file = 
"numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, + {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"}, + {file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"}, + {file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"}, + {file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"}, + {file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"}, + {file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"}, + {file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"}, + {file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"}, + {file = 
"numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"}, + {file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"}, + {file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"}, + {file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"}, + {file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"}, + {file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"}, + {file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"}, + {file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"}, + {file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"}, + {file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"}, + {file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"}, + {file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"}, + {file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"}, + {file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"}, + {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, ] [[package]] @@ -2595,23 +2659,24 @@ nvidia-nvjitlink-cu12 = "*" [[package]] name = "nvidia-nccl-cu12" -version = "2.18.1" +version = "2.19.3" description = "NVIDIA Collective Communication Library (NCCL) Runtime" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:1a6c4acefcbebfa6de320f412bf7866de856e786e0462326ba1bac40de0b5e71"}, + {file = "nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl", hash = "sha256:a9734707a2c96443331c1e48c717024aa6678a0e2a4cb66b2c364d18cee6b48d"}, ] [[package]] name = "nvidia-nvjitlink-cu12" -version = 
"12.3.101" +version = "12.4.99" description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, - {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"}, + {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_aarch64.whl", hash = "sha256:75d6498c96d9adb9435f2bbdbddb479805ddfb97b5c1b32395c694185c20ca57"}, + {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"}, + {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"}, ] [[package]] @@ -2627,24 +2692,24 @@ files = [ [[package]] name = "overrides" -version = "7.4.0" +version = "7.7.0" description = "A decorator to automatically detect mismatch when overriding a method." 
optional = false python-versions = ">=3.6" files = [ - {file = "overrides-7.4.0-py3-none-any.whl", hash = "sha256:3ad24583f86d6d7a49049695efe9933e67ba62f0c7625d53c59fa832ce4b8b7d"}, - {file = "overrides-7.4.0.tar.gz", hash = "sha256:9502a3cca51f4fac40b5feca985b6703a5c1f6ad815588a7ca9e285b9dca6757"}, + {file = "overrides-7.7.0-py3-none-any.whl", hash = "sha256:c7ed9d062f78b8e4c1a7b70bd8796b35ead4d9f510227ef9c5dc7626c60d7e49"}, + {file = "overrides-7.7.0.tar.gz", hash = "sha256:55158fa3d93b98cc75299b1e67078ad9003ca27945c76162c1c0766d6f91820a"}, ] [[package]] name = "packaging" -version = "23.2" +version = "24.0" description = "Core utilities for Python packages" optional = false python-versions = ">=3.7" files = [ - {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, - {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, + {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, + {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, ] [[package]] @@ -2684,8 +2749,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -2730,13 +2795,13 @@ ply = "*" [[package]] name = "pandocfilters" -version = "1.5.0" +version = "1.5.1" description = "Utilities for writing pandoc filters in python" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ - {file = "pandocfilters-1.5.0-py2.py3-none-any.whl", hash = 
"sha256:33aae3f25fd1a026079f5d27bdd52496f0e0803b3469282162bafdcbdf6ef14f"}, - {file = "pandocfilters-1.5.0.tar.gz", hash = "sha256:0b679503337d233b4339a817bfc8c50064e2eff681314376a47cb582305a7a38"}, + {file = "pandocfilters-1.5.1-py2.py3-none-any.whl", hash = "sha256:93be382804a9cdb0a7267585f157e5d1731bbe5545a85b268d6f5fe6232de2bc"}, + {file = "pandocfilters-1.5.1.tar.gz", hash = "sha256:002b4a555ee4ebc03f8b66307e287fa492e4a77b4ea14d3f934328297bb4939e"}, ] [[package]] @@ -2803,28 +2868,28 @@ files = [ [[package]] name = "platformdirs" -version = "4.1.0" +version = "4.2.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." optional = false python-versions = ">=3.8" files = [ - {file = "platformdirs-4.1.0-py3-none-any.whl", hash = "sha256:11c8f37bcca40db96d8144522d925583bdb7a31f7b0e37e3ed4318400a8e2380"}, - {file = "platformdirs-4.1.0.tar.gz", hash = "sha256:906d548203468492d432bcb294d4bc2fff751bf84971fbb2c10918cc206ee420"}, + {file = "platformdirs-4.2.0-py3-none-any.whl", hash = "sha256:0614df2a2f37e1a662acbd8e2b25b92ccf8632929bc6d43467e17fe89c75e068"}, + {file = "platformdirs-4.2.0.tar.gz", hash = "sha256:ef0cc731df711022c174543cb70a9b5bd22e5a9337c8624ef2c2ceb8ddad8768"}, ] [package.extras] -docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] +docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] [[package]] name = "plotly" -version = "5.18.0" +version = "5.20.0" description = "An open-source, interactive data visualization library for Python" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = 
"plotly-5.18.0-py3-none-any.whl", hash = "sha256:23aa8ea2f4fb364a20d34ad38235524bd9d691bf5299e800bca608c31e8db8de"}, - {file = "plotly-5.18.0.tar.gz", hash = "sha256:360a31e6fbb49d12b007036eb6929521343d6bee2236f8459915821baefa2cbb"}, + {file = "plotly-5.20.0-py3-none-any.whl", hash = "sha256:837a9c8aa90f2c0a2f0d747b82544d014dc2a2bdde967b5bb1da25b53932d1a9"}, + {file = "plotly-5.20.0.tar.gz", hash = "sha256:bf901c805d22032cfa534b2ff7c5aa6b0659e037f19ec1e0cca7f585918b5c89"}, ] [package.dependencies] @@ -2833,13 +2898,13 @@ tenacity = ">=6.2.0" [[package]] name = "pluggy" -version = "1.3.0" +version = "1.4.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" files = [ - {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"}, - {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"}, + {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"}, + {file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"}, ] [package.extras] @@ -2892,13 +2957,13 @@ six = ">=1.5.2" [[package]] name = "prometheus-client" -version = "0.19.0" +version = "0.20.0" description = "Python client for the Prometheus monitoring system." 
optional = false python-versions = ">=3.8" files = [ - {file = "prometheus_client-0.19.0-py3-none-any.whl", hash = "sha256:c88b1e6ecf6b41cd8fb5731c7ae919bf66df6ec6fafa555cd6c0e16ca169ae92"}, - {file = "prometheus_client-0.19.0.tar.gz", hash = "sha256:4585b0d1223148c27a225b10dbec5ae9bc4c81a99a3fa80774fa6209935324e1"}, + {file = "prometheus_client-0.20.0-py3-none-any.whl", hash = "sha256:cde524a85bce83ca359cc837f28b8c0db5cac7aa653a588fd7e84ba061c329e7"}, + {file = "prometheus_client-0.20.0.tar.gz", hash = "sha256:287629d00b147a32dcb2be0b9df905da599b2d82f80377083ec8463309a4bb89"}, ] [package.extras] @@ -2920,47 +2985,47 @@ wcwidth = "*" [[package]] name = "protobuf" -version = "4.25.1" +version = "4.25.3" description = "" optional = false python-versions = ">=3.8" files = [ - {file = "protobuf-4.25.1-cp310-abi3-win32.whl", hash = "sha256:193f50a6ab78a970c9b4f148e7c750cfde64f59815e86f686c22e26b4fe01ce7"}, - {file = "protobuf-4.25.1-cp310-abi3-win_amd64.whl", hash = "sha256:3497c1af9f2526962f09329fd61a36566305e6c72da2590ae0d7d1322818843b"}, - {file = "protobuf-4.25.1-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:0bf384e75b92c42830c0a679b0cd4d6e2b36ae0cf3dbb1e1dfdda48a244f4bcd"}, - {file = "protobuf-4.25.1-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:0f881b589ff449bf0b931a711926e9ddaad3b35089cc039ce1af50b21a4ae8cb"}, - {file = "protobuf-4.25.1-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:ca37bf6a6d0046272c152eea90d2e4ef34593aaa32e8873fc14c16440f22d4b7"}, - {file = "protobuf-4.25.1-cp38-cp38-win32.whl", hash = "sha256:abc0525ae2689a8000837729eef7883b9391cd6aa7950249dcf5a4ede230d5dd"}, - {file = "protobuf-4.25.1-cp38-cp38-win_amd64.whl", hash = "sha256:1484f9e692091450e7edf418c939e15bfc8fc68856e36ce399aed6889dae8bb0"}, - {file = "protobuf-4.25.1-cp39-cp39-win32.whl", hash = "sha256:8bdbeaddaac52d15c6dce38c71b03038ef7772b977847eb6d374fc86636fa510"}, - {file = "protobuf-4.25.1-cp39-cp39-win_amd64.whl", hash = 
"sha256:becc576b7e6b553d22cbdf418686ee4daa443d7217999125c045ad56322dda10"}, - {file = "protobuf-4.25.1-py3-none-any.whl", hash = "sha256:a19731d5e83ae4737bb2a089605e636077ac001d18781b3cf489b9546c7c80d6"}, - {file = "protobuf-4.25.1.tar.gz", hash = "sha256:57d65074b4f5baa4ab5da1605c02be90ac20c8b40fb137d6a8df9f416b0d0ce2"}, + {file = "protobuf-4.25.3-cp310-abi3-win32.whl", hash = "sha256:d4198877797a83cbfe9bffa3803602bbe1625dc30d8a097365dbc762e5790faa"}, + {file = "protobuf-4.25.3-cp310-abi3-win_amd64.whl", hash = "sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8"}, + {file = "protobuf-4.25.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:f1279ab38ecbfae7e456a108c5c0681e4956d5b1090027c1de0f934dfdb4b35c"}, + {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:e7cb0ae90dd83727f0c0718634ed56837bfeeee29a5f82a7514c03ee1364c019"}, + {file = "protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:7c8daa26095f82482307bc717364e7c13f4f1c99659be82890dcfc215194554d"}, + {file = "protobuf-4.25.3-cp38-cp38-win32.whl", hash = "sha256:f4f118245c4a087776e0a8408be33cf09f6c547442c00395fbfb116fac2f8ac2"}, + {file = "protobuf-4.25.3-cp38-cp38-win_amd64.whl", hash = "sha256:c053062984e61144385022e53678fbded7aea14ebb3e0305ae3592fb219ccfa4"}, + {file = "protobuf-4.25.3-cp39-cp39-win32.whl", hash = "sha256:19b270aeaa0099f16d3ca02628546b8baefe2955bbe23224aaf856134eccf1e4"}, + {file = "protobuf-4.25.3-cp39-cp39-win_amd64.whl", hash = "sha256:e3c97a1555fd6388f857770ff8b9703083de6bf1f9274a002a332d65fbb56c8c"}, + {file = "protobuf-4.25.3-py3-none-any.whl", hash = "sha256:f0700d54bcf45424477e46a9f0944155b46fb0639d69728739c0e47bab83f2b9"}, + {file = "protobuf-4.25.3.tar.gz", hash = "sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c"}, ] [[package]] name = "psutil" -version = "5.9.7" +version = "5.9.8" description = "Cross-platform lib for process and system monitoring in Python." 
optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ - {file = "psutil-5.9.7-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:0bd41bf2d1463dfa535942b2a8f0e958acf6607ac0be52265ab31f7923bcd5e6"}, - {file = "psutil-5.9.7-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:5794944462509e49d4d458f4dbfb92c47539e7d8d15c796f141f474010084056"}, - {file = "psutil-5.9.7-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:fe361f743cb3389b8efda21980d93eb55c1f1e3898269bc9a2a1d0bb7b1f6508"}, - {file = "psutil-5.9.7-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:e469990e28f1ad738f65a42dcfc17adaed9d0f325d55047593cb9033a0ab63df"}, - {file = "psutil-5.9.7-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:3c4747a3e2ead1589e647e64aad601981f01b68f9398ddf94d01e3dc0d1e57c7"}, - {file = "psutil-5.9.7-cp27-none-win32.whl", hash = "sha256:1d4bc4a0148fdd7fd8f38e0498639ae128e64538faa507df25a20f8f7fb2341c"}, - {file = "psutil-5.9.7-cp27-none-win_amd64.whl", hash = "sha256:4c03362e280d06bbbfcd52f29acd79c733e0af33d707c54255d21029b8b32ba6"}, - {file = "psutil-5.9.7-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ea36cc62e69a13ec52b2f625c27527f6e4479bca2b340b7a452af55b34fcbe2e"}, - {file = "psutil-5.9.7-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1132704b876e58d277168cd729d64750633d5ff0183acf5b3c986b8466cd0284"}, - {file = "psutil-5.9.7-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe8b7f07948f1304497ce4f4684881250cd859b16d06a1dc4d7941eeb6233bfe"}, - {file = "psutil-5.9.7-cp36-cp36m-win32.whl", hash = "sha256:b27f8fdb190c8c03914f908a4555159327d7481dac2f01008d483137ef3311a9"}, - {file = "psutil-5.9.7-cp36-cp36m-win_amd64.whl", hash = "sha256:44969859757f4d8f2a9bd5b76eba8c3099a2c8cf3992ff62144061e39ba8568e"}, - {file = "psutil-5.9.7-cp37-abi3-win32.whl", hash = 
"sha256:c727ca5a9b2dd5193b8644b9f0c883d54f1248310023b5ad3e92036c5e2ada68"}, - {file = "psutil-5.9.7-cp37-abi3-win_amd64.whl", hash = "sha256:f37f87e4d73b79e6c5e749440c3113b81d1ee7d26f21c19c47371ddea834f414"}, - {file = "psutil-5.9.7-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:032f4f2c909818c86cea4fe2cc407f1c0f0cde8e6c6d702b28b8ce0c0d143340"}, - {file = "psutil-5.9.7.tar.gz", hash = "sha256:3f02134e82cfb5d089fddf20bb2e03fd5cd52395321d1c8458a9e58500ff417c"}, + {file = "psutil-5.9.8-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:26bd09967ae00920df88e0352a91cff1a78f8d69b3ecabbfe733610c0af486c8"}, + {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:05806de88103b25903dff19bb6692bd2e714ccf9e668d050d144012055cbca73"}, + {file = "psutil-5.9.8-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:611052c4bc70432ec770d5d54f64206aa7203a101ec273a0cd82418c86503bb7"}, + {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:50187900d73c1381ba1454cf40308c2bf6f34268518b3f36a9b663ca87e65e36"}, + {file = "psutil-5.9.8-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:02615ed8c5ea222323408ceba16c60e99c3f91639b07da6373fb7e6539abc56d"}, + {file = "psutil-5.9.8-cp27-none-win32.whl", hash = "sha256:36f435891adb138ed3c9e58c6af3e2e6ca9ac2f365efe1f9cfef2794e6c93b4e"}, + {file = "psutil-5.9.8-cp27-none-win_amd64.whl", hash = "sha256:bd1184ceb3f87651a67b2708d4c3338e9b10c5df903f2e3776b62303b26cb631"}, + {file = "psutil-5.9.8-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81"}, + {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421"}, + {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4"}, + {file = "psutil-5.9.8-cp36-cp36m-win32.whl", hash = "sha256:7d79560ad97af658a0f6adfef8b834b53f64746d45b403f225b85c5c2c140eee"}, + {file = "psutil-5.9.8-cp36-cp36m-win_amd64.whl", hash = "sha256:27cc40c3493bb10de1be4b3f07cae4c010ce715290a5be22b98493509c6299e2"}, + {file = "psutil-5.9.8-cp37-abi3-win32.whl", hash = "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0"}, + {file = "psutil-5.9.8-cp37-abi3-win_amd64.whl", hash = "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf"}, + {file = "psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8"}, + {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"}, ] [package.extras] @@ -2993,51 +3058,62 @@ tests = ["pytest"] [[package]] name = "pyarrow" -version = "14.0.2" +version = "15.0.2" description = "Python library for Apache Arrow" optional = false python-versions = ">=3.8" files = [ - {file = "pyarrow-14.0.2-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:ba9fe808596c5dbd08b3aeffe901e5f81095baaa28e7d5118e01354c64f22807"}, - {file = "pyarrow-14.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:22a768987a16bb46220cef490c56c671993fbee8fd0475febac0b3e16b00a10e"}, - {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2dbba05e98f247f17e64303eb876f4a80fcd32f73c7e9ad975a83834d81f3fda"}, - {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a898d134d00b1eca04998e9d286e19653f9d0fcb99587310cd10270907452a6b"}, - {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:87e879323f256cb04267bb365add7208f302df942eb943c93a9dfeb8f44840b1"}, - {file = "pyarrow-14.0.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = 
"sha256:76fc257559404ea5f1306ea9a3ff0541bf996ff3f7b9209fc517b5e83811fa8e"}, - {file = "pyarrow-14.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:b0c4a18e00f3a32398a7f31da47fefcd7a927545b396e1f15d0c85c2f2c778cd"}, - {file = "pyarrow-14.0.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:87482af32e5a0c0cce2d12eb3c039dd1d853bd905b04f3f953f147c7a196915b"}, - {file = "pyarrow-14.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:059bd8f12a70519e46cd64e1ba40e97eae55e0cbe1695edd95384653d7626b23"}, - {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f16111f9ab27e60b391c5f6d197510e3ad6654e73857b4e394861fc79c37200"}, - {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06ff1264fe4448e8d02073f5ce45a9f934c0f3db0a04460d0b01ff28befc3696"}, - {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd4f4b472ccf4042f1eab77e6c8bce574543f54d2135c7e396f413046397d5a"}, - {file = "pyarrow-14.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:32356bfb58b36059773f49e4e214996888eeea3a08893e7dbde44753799b2a02"}, - {file = "pyarrow-14.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:52809ee69d4dbf2241c0e4366d949ba035cbcf48409bf404f071f624ed313a2b"}, - {file = "pyarrow-14.0.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c87824a5ac52be210d32906c715f4ed7053d0180c1060ae3ff9b7e560f53f944"}, - {file = "pyarrow-14.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a25eb2421a58e861f6ca91f43339d215476f4fe159eca603c55950c14f378cc5"}, - {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c1da70d668af5620b8ba0a23f229030a4cd6c5f24a616a146f30d2386fec422"}, - {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cc61593c8e66194c7cdfae594503e91b926a228fba40b5cf25cc593563bcd07"}, - {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = 
"sha256:78ea56f62fb7c0ae8ecb9afdd7893e3a7dbeb0b04106f5c08dbb23f9c0157591"}, - {file = "pyarrow-14.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:37c233ddbce0c67a76c0985612fef27c0c92aef9413cf5aa56952f359fcb7379"}, - {file = "pyarrow-14.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:e4b123ad0f6add92de898214d404e488167b87b5dd86e9a434126bc2b7a5578d"}, - {file = "pyarrow-14.0.2-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:e354fba8490de258be7687f341bc04aba181fc8aa1f71e4584f9890d9cb2dec2"}, - {file = "pyarrow-14.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:20e003a23a13da963f43e2b432483fdd8c38dc8882cd145f09f21792e1cf22a1"}, - {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc0de7575e841f1595ac07e5bc631084fd06ca8b03c0f2ecece733d23cd5102a"}, - {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66e986dc859712acb0bd45601229021f3ffcdfc49044b64c6d071aaf4fa49e98"}, - {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:f7d029f20ef56673a9730766023459ece397a05001f4e4d13805111d7c2108c0"}, - {file = "pyarrow-14.0.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:209bac546942b0d8edc8debda248364f7f668e4aad4741bae58e67d40e5fcf75"}, - {file = "pyarrow-14.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:1e6987c5274fb87d66bb36816afb6f65707546b3c45c44c28e3c4133c010a881"}, - {file = "pyarrow-14.0.2-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:a01d0052d2a294a5f56cc1862933014e696aa08cc7b620e8c0cce5a5d362e976"}, - {file = "pyarrow-14.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a51fee3a7db4d37f8cda3ea96f32530620d43b0489d169b285d774da48ca9785"}, - {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64df2bf1ef2ef14cee531e2dfe03dd924017650ffaa6f9513d7a1bb291e59c15"}, - {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:3c0fa3bfdb0305ffe09810f9d3e2e50a2787e3a07063001dcd7adae0cee3601a"}, - {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c65bf4fd06584f058420238bc47a316e80dda01ec0dfb3044594128a6c2db794"}, - {file = "pyarrow-14.0.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:63ac901baec9369d6aae1cbe6cca11178fb018a8d45068aaf5bb54f94804a866"}, - {file = "pyarrow-14.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:75ee0efe7a87a687ae303d63037d08a48ef9ea0127064df18267252cfe2e9541"}, - {file = "pyarrow-14.0.2.tar.gz", hash = "sha256:36cef6ba12b499d864d1def3e990f97949e0b79400d08b7cf74504ffbd3eb025"}, + {file = "pyarrow-15.0.2-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:88b340f0a1d05b5ccc3d2d986279045655b1fe8e41aba6ca44ea28da0d1455d8"}, + {file = "pyarrow-15.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eaa8f96cecf32da508e6c7f69bb8401f03745c050c1dd42ec2596f2e98deecac"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23c6753ed4f6adb8461e7c383e418391b8d8453c5d67e17f416c3a5d5709afbd"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f639c059035011db8c0497e541a8a45d98a58dbe34dc8fadd0ef128f2cee46e5"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:290e36a59a0993e9a5224ed2fb3e53375770f07379a0ea03ee2fce2e6d30b423"}, + {file = "pyarrow-15.0.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:06c2bb2a98bc792f040bef31ad3e9be6a63d0cb39189227c08a7d955db96816e"}, + {file = "pyarrow-15.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:f7a197f3670606a960ddc12adbe8075cea5f707ad7bf0dffa09637fdbb89f76c"}, + {file = "pyarrow-15.0.2-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:5f8bc839ea36b1f99984c78e06e7a06054693dc2af8920f6fb416b5bca9944e4"}, + {file = "pyarrow-15.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f5e81dfb4e519baa6b4c80410421528c214427e77ca0ea9461eb4097c328fa33"}, + {file = 
"pyarrow-15.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a4f240852b302a7af4646c8bfe9950c4691a419847001178662a98915fd7ee7"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e7d9cfb5a1e648e172428c7a42b744610956f3b70f524aa3a6c02a448ba853e"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2d4f905209de70c0eb5b2de6763104d5a9a37430f137678edfb9a675bac9cd98"}, + {file = "pyarrow-15.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:90adb99e8ce5f36fbecbbc422e7dcbcbed07d985eed6062e459e23f9e71fd197"}, + {file = "pyarrow-15.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:b116e7fd7889294cbd24eb90cd9bdd3850be3738d61297855a71ac3b8124ee38"}, + {file = "pyarrow-15.0.2-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:25335e6f1f07fdaa026a61c758ee7d19ce824a866b27bba744348fa73bb5a440"}, + {file = "pyarrow-15.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90f19e976d9c3d8e73c80be84ddbe2f830b6304e4c576349d9360e335cd627fc"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a22366249bf5fd40ddacc4f03cd3160f2d7c247692945afb1899bab8a140ddfb"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2a335198f886b07e4b5ea16d08ee06557e07db54a8400cc0d03c7f6a22f785f"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3e6d459c0c22f0b9c810a3917a1de3ee704b021a5fb8b3bacf968eece6df098f"}, + {file = "pyarrow-15.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:033b7cad32198754d93465dcfb71d0ba7cb7cd5c9afd7052cab7214676eec38b"}, + {file = "pyarrow-15.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:29850d050379d6e8b5a693098f4de7fd6a2bea4365bfd073d7c57c57b95041ee"}, + {file = "pyarrow-15.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:7167107d7fb6dcadb375b4b691b7e316f4368f39f6f45405a05535d7ad5e5058"}, + {file = 
"pyarrow-15.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e85241b44cc3d365ef950432a1b3bd44ac54626f37b2e3a0cc89c20e45dfd8bf"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:248723e4ed3255fcd73edcecc209744d58a9ca852e4cf3d2577811b6d4b59818"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ff3bdfe6f1b81ca5b73b70a8d482d37a766433823e0c21e22d1d7dde76ca33f"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:f3d77463dee7e9f284ef42d341689b459a63ff2e75cee2b9302058d0d98fe142"}, + {file = "pyarrow-15.0.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:8c1faf2482fb89766e79745670cbca04e7018497d85be9242d5350cba21357e1"}, + {file = "pyarrow-15.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:28f3016958a8e45a1069303a4a4f6a7d4910643fc08adb1e2e4a7ff056272ad3"}, + {file = "pyarrow-15.0.2-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:89722cb64286ab3d4daf168386f6968c126057b8c7ec3ef96302e81d8cdb8ae4"}, + {file = "pyarrow-15.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cd0ba387705044b3ac77b1b317165c0498299b08261d8122c96051024f953cd5"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad2459bf1f22b6a5cdcc27ebfd99307d5526b62d217b984b9f5c974651398832"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58922e4bfece8b02abf7159f1f53a8f4d9f8e08f2d988109126c17c3bb261f22"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:adccc81d3dc0478ea0b498807b39a8d41628fa9210729b2f718b78cb997c7c91"}, + {file = "pyarrow-15.0.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:8bd2baa5fe531571847983f36a30ddbf65261ef23e496862ece83bdceb70420d"}, + {file = "pyarrow-15.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6669799a1d4ca9da9c7e06ef48368320f5856f36f9a4dd31a11839dda3f6cc8c"}, + {file = "pyarrow-15.0.2.tar.gz", hash = 
"sha256:9c9bc803cb3b7bfacc1e96ffbfd923601065d9d3f911179d81e72d99fd74a3d9"}, ] [package.dependencies] -numpy = ">=1.16.6" +numpy = ">=1.16.6,<2" + +[[package]] +name = "pyarrow-hotfix" +version = "0.6" +description = "" +optional = false +python-versions = ">=3.5" +files = [ + {file = "pyarrow_hotfix-0.6-py3-none-any.whl", hash = "sha256:dcc9ae2d220dff0083be6a9aa8e0cdee5182ad358d4931fce825c545e5c89178"}, + {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"}, +] [[package]] name = "pycln" @@ -3085,13 +3161,13 @@ windows-terminal = ["colorama (>=0.4.6)"] [[package]] name = "pytest" -version = "7.4.4" +version = "8.1.1" description = "pytest: simple powerful testing with Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, - {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, + {file = "pytest-8.1.1-py3-none-any.whl", hash = "sha256:2a8386cfc11fa9d2c50ee7b2a57e7d898ef90470a7a34c4b949ff59662bb78b7"}, + {file = "pytest-8.1.1.tar.gz", hash = "sha256:ac978141a75948948817d360297b7aae0fcb9d6ff6bc9ec6d514b85d5a65c044"}, ] [package.dependencies] @@ -3099,21 +3175,21 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""} exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" -pluggy = ">=0.12,<2.0" -tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} +pluggy = ">=1.4,<2.0" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] -testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", 
"requests", "setuptools", "xmlschema"] [[package]] name = "pytest-cov" -version = "4.1.0" +version = "5.0.0" description = "Pytest plugin for measuring coverage." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, - {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, + {file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"}, + {file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"}, ] [package.dependencies] @@ -3121,17 +3197,17 @@ coverage = {version = ">=5.2.1", extras = ["toml"]} pytest = ">=4.6" [package.extras] -testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] [[package]] name = "pytest-doctestplus" -version = "1.1.0" +version = "1.2.1" description = "Pytest plugin with advanced doctest features." 
optional = false python-versions = ">=3.8" files = [ - {file = "pytest-doctestplus-1.1.0.tar.gz", hash = "sha256:ea0a710f1b6a3571ed971fb6d6e5db05a2ae6b91b0fbcafe30fb5ea40e9987c4"}, - {file = "pytest_doctestplus-1.1.0-py3-none-any.whl", hash = "sha256:b98d95b4956a03256c638f1f9f72200160e9885ab1cd40f35c4453bc1d2e32b2"}, + {file = "pytest-doctestplus-1.2.1.tar.gz", hash = "sha256:2472a8a2c8cea34d2f65f6499543faeb748eecb59c597852fd98839b47307679"}, + {file = "pytest_doctestplus-1.2.1-py3-none-any.whl", hash = "sha256:103705daee8d4468eb59d444c29b0d71eb85b8f6d582295c8bc3d68ee1d88911"}, ] [package.dependencies] @@ -3144,13 +3220,13 @@ test = ["numpy", "pytest-remotedata (>=0.3.2)", "sphinx"] [[package]] name = "python-dateutil" -version = "2.8.2" +version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ - {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, - {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, + {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, + {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, ] [package.dependencies] @@ -3169,13 +3245,13 @@ files = [ [[package]] name = "pytz" -version = "2023.3.post1" +version = "2024.1" description = "World timezone definitions, modern and historical" optional = false python-versions = "*" files = [ - {file = "pytz-2023.3.post1-py2.py3-none-any.whl", hash = "sha256:ce42d816b81b68506614c11e8937d3aa9e41007ceb50bfdcb0749b921bf646c7"}, - {file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"}, + {file = 
"pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, + {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, ] [[package]] @@ -3203,17 +3279,17 @@ files = [ [[package]] name = "pywinpty" -version = "2.0.12" +version = "2.0.13" description = "Pseudo terminal support for Windows from Python." optional = false python-versions = ">=3.8" files = [ - {file = "pywinpty-2.0.12-cp310-none-win_amd64.whl", hash = "sha256:21319cd1d7c8844fb2c970fb3a55a3db5543f112ff9cfcd623746b9c47501575"}, - {file = "pywinpty-2.0.12-cp311-none-win_amd64.whl", hash = "sha256:853985a8f48f4731a716653170cd735da36ffbdc79dcb4c7b7140bce11d8c722"}, - {file = "pywinpty-2.0.12-cp312-none-win_amd64.whl", hash = "sha256:1617b729999eb6713590e17665052b1a6ae0ad76ee31e60b444147c5b6a35dca"}, - {file = "pywinpty-2.0.12-cp38-none-win_amd64.whl", hash = "sha256:189380469ca143d06e19e19ff3fba0fcefe8b4a8cc942140a6b863aed7eebb2d"}, - {file = "pywinpty-2.0.12-cp39-none-win_amd64.whl", hash = "sha256:7520575b6546db23e693cbd865db2764097bd6d4ef5dc18c92555904cd62c3d4"}, - {file = "pywinpty-2.0.12.tar.gz", hash = "sha256:8197de460ae8ebb7f5d1701dfa1b5df45b157bb832e92acba316305e18ca00dd"}, + {file = "pywinpty-2.0.13-cp310-none-win_amd64.whl", hash = "sha256:697bff211fb5a6508fee2dc6ff174ce03f34a9a233df9d8b5fe9c8ce4d5eaf56"}, + {file = "pywinpty-2.0.13-cp311-none-win_amd64.whl", hash = "sha256:b96fb14698db1284db84ca38c79f15b4cfdc3172065b5137383910567591fa99"}, + {file = "pywinpty-2.0.13-cp312-none-win_amd64.whl", hash = "sha256:2fd876b82ca750bb1333236ce98488c1be96b08f4f7647cfdf4129dfad83c2d4"}, + {file = "pywinpty-2.0.13-cp38-none-win_amd64.whl", hash = "sha256:61d420c2116c0212808d31625611b51caf621fe67f8a6377e2e8b617ea1c1f7d"}, + {file = "pywinpty-2.0.13-cp39-none-win_amd64.whl", hash = "sha256:71cb613a9ee24174730ac7ae439fd179ca34ccb8c5349e8d7b72ab5dea2c6f4b"}, + {file = "pywinpty-2.0.13.tar.gz", hash = 
"sha256:c34e32351a3313ddd0d7da23d27f835c860d32fe4ac814d372a3ea9594f41dde"}, ] [[package]] @@ -3425,13 +3501,13 @@ test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] [[package]] name = "referencing" -version = "0.32.1" +version = "0.34.0" description = "JSON Referencing + Python" optional = false python-versions = ">=3.8" files = [ - {file = "referencing-0.32.1-py3-none-any.whl", hash = "sha256:7e4dc12271d8e15612bfe35792f5ea1c40970dadf8624602e33db2758f7ee554"}, - {file = "referencing-0.32.1.tar.gz", hash = "sha256:3c57da0513e9563eb7e203ebe9bb3a1b509b042016433bd1e45a2853466c3dd3"}, + {file = "referencing-0.34.0-py3-none-any.whl", hash = "sha256:d53ae300ceddd3169f1ffa9caf2cb7b769e92657e4fafb23d34b93679116dfd4"}, + {file = "referencing-0.34.0.tar.gz", hash = "sha256:5773bd84ef41799a5a8ca72dc34590c041eb01bf9aa02632b4a973fb0181a844"}, ] [package.dependencies] @@ -3588,13 +3664,13 @@ files = [ [[package]] name = "rich" -version = "13.7.0" +version = "13.7.1" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false python-versions = ">=3.7.0" files = [ - {file = "rich-13.7.0-py3-none-any.whl", hash = "sha256:6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235"}, - {file = "rich-13.7.0.tar.gz", hash = "sha256:5cb5123b5cf9ee70584244246816e9114227e0b98ad9176eede6ad54bf5403fa"}, + {file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"}, + {file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"}, ] [package.dependencies] @@ -3607,223 +3683,236 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] [[package]] name = "rpds-py" -version = "0.16.2" +version = "0.18.0" description = "Python bindings to Rust's persistent data structures (rpds)" optional = false python-versions = ">=3.8" files = [ - {file = "rpds_py-0.16.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = 
"sha256:509b617ac787cd1149600e731db9274ebbef094503ca25158e6f23edaba1ca8f"}, - {file = "rpds_py-0.16.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:413b9c17388bbd0d87a329d8e30c1a4c6e44e2bb25457f43725a8e6fe4161e9e"}, - {file = "rpds_py-0.16.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2946b120718eba9af2b4dd103affc1164a87b9e9ebff8c3e4c05d7b7a7e274e2"}, - {file = "rpds_py-0.16.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:35ae5ece284cf36464eb160880018cf6088a9ac5ddc72292a6092b6ef3f4da53"}, - {file = "rpds_py-0.16.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3dc6a7620ba7639a3db6213da61312cb4aa9ac0ca6e00dc1cbbdc21c2aa6eb57"}, - {file = "rpds_py-0.16.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8cb6fe8ecdfffa0e711a75c931fb39f4ba382b4b3ccedeca43f18693864fe850"}, - {file = "rpds_py-0.16.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dace7b26a13353e24613417ce2239491b40a6ad44e5776a18eaff7733488b44"}, - {file = "rpds_py-0.16.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1bdbc5fcb04a7309074de6b67fa9bc4b418ab3fc435fec1f2779a0eced688d04"}, - {file = "rpds_py-0.16.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f42e25c016927e2a6b1ce748112c3ab134261fc2ddc867e92d02006103e1b1b7"}, - {file = "rpds_py-0.16.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:eab36eae3f3e8e24b05748ec9acc66286662f5d25c52ad70cadab544e034536b"}, - {file = "rpds_py-0.16.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:0474df4ade9a3b4af96c3d36eb81856cb9462e4c6657d4caecfd840d2a13f3c9"}, - {file = "rpds_py-0.16.2-cp310-none-win32.whl", hash = "sha256:84c5a4d1f9dd7e2d2c44097fb09fffe728629bad31eb56caf97719e55575aa82"}, - {file = "rpds_py-0.16.2-cp310-none-win_amd64.whl", hash = "sha256:2bd82db36cd70b3628c0c57d81d2438e8dd4b7b32a6a9f25f24ab0e657cb6c4e"}, - {file = 
"rpds_py-0.16.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:adc0c3d6fc6ae35fee3e4917628983f6ce630d513cbaad575b4517d47e81b4bb"}, - {file = "rpds_py-0.16.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ec23fcad480e77ede06cf4127a25fc440f7489922e17fc058f426b5256ee0edb"}, - {file = "rpds_py-0.16.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07aab64e2808c3ebac2a44f67e9dc0543812b715126dfd6fe4264df527556cb6"}, - {file = "rpds_py-0.16.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a4ebb8b20bd09c5ce7884c8f0388801100f5e75e7f733b1b6613c713371feefc"}, - {file = "rpds_py-0.16.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a3d7e2ea25d3517c6d7e5a1cc3702cffa6bd18d9ef8d08d9af6717fc1c700eed"}, - {file = "rpds_py-0.16.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f28ac0e8e7242d140f99402a903a2c596ab71550272ae9247ad78f9a932b5698"}, - {file = "rpds_py-0.16.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19f00f57fdd38db4bb5ad09f9ead1b535332dbf624200e9029a45f1f35527ebb"}, - {file = "rpds_py-0.16.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3da5a4c56953bdbf6d04447c3410309616c54433146ccdb4a277b9cb499bc10e"}, - {file = "rpds_py-0.16.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ec2e1cf025b2c0f48ec17ff3e642661da7ee332d326f2e6619366ce8e221f018"}, - {file = "rpds_py-0.16.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e0441fb4fdd39a230477b2ca9be90868af64425bfe7b122b57e61e45737a653b"}, - {file = "rpds_py-0.16.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9f0350ef2fba5f34eb0c9000ea328e51b9572b403d2f7f3b19f24085f6f598e8"}, - {file = "rpds_py-0.16.2-cp311-none-win32.whl", hash = "sha256:5a80e2f83391ad0808b4646732af2a7b67550b98f0cae056cb3b40622a83dbb3"}, - {file = "rpds_py-0.16.2-cp311-none-win_amd64.whl", hash = 
"sha256:e04e56b4ca7a770593633556e8e9e46579d66ec2ada846b401252a2bdcf70a6d"}, - {file = "rpds_py-0.16.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:5e6caa3809e50690bd92fa490f5c38caa86082c8c3315aa438bce43786d5e90d"}, - {file = "rpds_py-0.16.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2e53b9b25cac9065328901713a7e9e3b12e4f57ef4280b370fbbf6fef2052eef"}, - {file = "rpds_py-0.16.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af27423662f32d7501a00c5e7342f7dbd1e4a718aea7a239781357d15d437133"}, - {file = "rpds_py-0.16.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:43d4dd5fb16eb3825742bad8339d454054261ab59fed2fbac84e1d84d5aae7ba"}, - {file = "rpds_py-0.16.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e061de3b745fe611e23cd7318aec2c8b0e4153939c25c9202a5811ca911fd733"}, - {file = "rpds_py-0.16.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b811d182ad17ea294f2ec63c0621e7be92a1141e1012383461872cead87468f"}, - {file = "rpds_py-0.16.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5552f328eaef1a75ff129d4d0c437bf44e43f9436d3996e8eab623ea0f5fcf73"}, - {file = "rpds_py-0.16.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dcbe1f8dd179e4d69b70b1f1d9bb6fd1e7e1bdc9c9aad345cdeb332e29d40748"}, - {file = "rpds_py-0.16.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8aad80645a011abae487d356e0ceb359f4938dfb6f7bcc410027ed7ae4f7bb8b"}, - {file = "rpds_py-0.16.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b6f5549d6ed1da9bfe3631ca9483ae906f21410be2445b73443fa9f017601c6f"}, - {file = "rpds_py-0.16.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d452817e0d9c749c431a1121d56a777bd7099b720b3d1c820f1725cb40928f58"}, - {file = "rpds_py-0.16.2-cp312-none-win32.whl", hash = "sha256:888a97002e986eca10d8546e3c8b97da1d47ad8b69726dcfeb3e56348ebb28a3"}, - {file = 
"rpds_py-0.16.2-cp312-none-win_amd64.whl", hash = "sha256:d8dda2a806dfa4a9b795950c4f5cc56d6d6159f7d68080aedaff3bdc9b5032f5"}, - {file = "rpds_py-0.16.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:071980663c273bf3d388fe5c794c547e6f35ba3335477072c713a3176bf14a60"}, - {file = "rpds_py-0.16.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:726ac36e8a3bb8daef2fd482534cabc5e17334052447008405daca7ca04a3108"}, - {file = "rpds_py-0.16.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e9e557db6a177470316c82f023e5d571811c9a4422b5ea084c85da9aa3c035fc"}, - {file = "rpds_py-0.16.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:90123853fc8b1747f80b0d354be3d122b4365a93e50fc3aacc9fb4c2488845d6"}, - {file = "rpds_py-0.16.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a61f659665a39a4d17d699ab3593d7116d66e1e2e3f03ef3fb8f484e91908808"}, - {file = "rpds_py-0.16.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc97f0640e91d7776530f06e6836c546c1c752a52de158720c4224c9e8053cad"}, - {file = "rpds_py-0.16.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44a54e99a2b9693a37ebf245937fd6e9228b4cbd64b9cc961e1f3391ec6c7391"}, - {file = "rpds_py-0.16.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bd4b677d929cf1f6bac07ad76e0f2d5de367e6373351c01a9c0a39f6b21b4a8b"}, - {file = "rpds_py-0.16.2-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:5ef00873303d678aaf8b0627e111fd434925ca01c657dbb2641410f1cdaef261"}, - {file = "rpds_py-0.16.2-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:349cb40897fd529ca15317c22c0eab67f5ac5178b5bd2c6adc86172045210acc"}, - {file = "rpds_py-0.16.2-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:2ddef620e70eaffebed5932ce754d539c0930f676aae6212f8e16cd9743dd365"}, - {file = "rpds_py-0.16.2-cp38-none-win32.whl", hash = "sha256:882ce6e25e585949c3d9f9abd29202367175e0aab3aba0c58c9abbb37d4982ff"}, - {file = 
"rpds_py-0.16.2-cp38-none-win_amd64.whl", hash = "sha256:f4bd4578e44f26997e9e56c96dedc5f1af43cc9d16c4daa29c771a00b2a26851"}, - {file = "rpds_py-0.16.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:69ac7ea9897ec201ce68b48582f3eb34a3f9924488a5432a93f177bf76a82a7e"}, - {file = "rpds_py-0.16.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a9880b4656efe36ccad41edc66789e191e5ee19a1ea8811e0aed6f69851a82f4"}, - {file = "rpds_py-0.16.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee94cb58c0ba2c62ee108c2b7c9131b2c66a29e82746e8fa3aa1a1effbd3dcf1"}, - {file = "rpds_py-0.16.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:24f7a2eb3866a9e91f4599851e0c8d39878a470044875c49bd528d2b9b88361c"}, - {file = "rpds_py-0.16.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ca57468da2d9a660bcf8961637c85f2fbb2aa64d9bc3f9484e30c3f9f67b1dd7"}, - {file = "rpds_py-0.16.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ccd4e400309e1f34a5095bf9249d371f0fd60f8a3a5c4a791cad7b99ce1fd38d"}, - {file = "rpds_py-0.16.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80443fe2f7b3ea3934c5d75fb0e04a5dbb4a8e943e5ff2de0dec059202b70a8b"}, - {file = "rpds_py-0.16.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4d6a9f052e72d493efd92a77f861e45bab2f6be63e37fa8ecf0c6fd1a58fedb0"}, - {file = "rpds_py-0.16.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:35953f4f2b3216421af86fd236b7c0c65935936a94ea83ddbd4904ba60757773"}, - {file = "rpds_py-0.16.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:981d135c7cdaf6cd8eadae1c950de43b976de8f09d8e800feed307140d3d6d00"}, - {file = "rpds_py-0.16.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d0dd7ed2f16df2e129496e7fbe59a34bc2d7fc8db443a606644d069eb69cbd45"}, - {file = "rpds_py-0.16.2-cp39-none-win32.whl", hash = "sha256:703d95c75a72e902544fda08e965885525e297578317989fd15a6ce58414b41d"}, - {file = 
"rpds_py-0.16.2-cp39-none-win_amd64.whl", hash = "sha256:e93ec1b300acf89730cf27975ef574396bc04edecc358e9bd116fb387a123239"}, - {file = "rpds_py-0.16.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:44627b6ca7308680a70766454db5249105fa6344853af6762eaad4158a2feebe"}, - {file = "rpds_py-0.16.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:3f91df8e6dbb7360e176d1affd5fb0246d2b88d16aa5ebc7db94fd66b68b61da"}, - {file = "rpds_py-0.16.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d904c5693e08bad240f16d79305edba78276be87061c872a4a15e2c301fa2c0"}, - {file = "rpds_py-0.16.2-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:290a81cfbe4673285cdf140ec5cd1658ffbf63ab359f2b352ebe172e7cfa5bf0"}, - {file = "rpds_py-0.16.2-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b634c5ec0103c5cbebc24ebac4872b045cccb9456fc59efdcf6fe39775365bd2"}, - {file = "rpds_py-0.16.2-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a297a4d08cc67c7466c873c78039d87840fb50d05473db0ec1b7b03d179bf322"}, - {file = "rpds_py-0.16.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2e75e17bd0bb66ee34a707da677e47c14ee51ccef78ed6a263a4cc965a072a1"}, - {file = "rpds_py-0.16.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f1b9d9260e06ea017feb7172976ab261e011c1dc2f8883c7c274f6b2aabfe01a"}, - {file = "rpds_py-0.16.2-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:162d7cd9cd311c1b0ff1c55a024b8f38bd8aad1876b648821da08adc40e95734"}, - {file = "rpds_py-0.16.2-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:9b32f742ce5b57201305f19c2ef7a184b52f6f9ba6871cc042c2a61f0d6b49b8"}, - {file = "rpds_py-0.16.2-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:ac08472f41ea77cd6a5dae36ae7d4ed3951d6602833af87532b556c1b4601d63"}, - {file = 
"rpds_py-0.16.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:495a14b72bbe217f2695dcd9b5ab14d4f8066a00f5d209ed94f0aca307f85f6e"}, - {file = "rpds_py-0.16.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:8d6b6937ae9eac6d6c0ca3c42774d89fa311f55adff3970fb364b34abde6ed3d"}, - {file = "rpds_py-0.16.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a61226465bda9283686db8f17d02569a98e4b13c637be5a26d44aa1f1e361c2"}, - {file = "rpds_py-0.16.2-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5cf6af100ffb5c195beec11ffaa8cf8523057f123afa2944e6571d54da84cdc9"}, - {file = "rpds_py-0.16.2-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6df15846ee3fb2e6397fe25d7ca6624af9f89587f3f259d177b556fed6bebe2c"}, - {file = "rpds_py-0.16.2-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1be2f033df1b8be8c3167ba3c29d5dca425592ee31e35eac52050623afba5772"}, - {file = "rpds_py-0.16.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96f957d6ab25a78b9e7fc9749d754b98eac825a112b4e666525ce89afcbd9ed5"}, - {file = "rpds_py-0.16.2-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:088396c7c70e59872f67462fcac3ecbded5233385797021976a09ebd55961dfe"}, - {file = "rpds_py-0.16.2-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4c46ad6356e1561f2a54f08367d1d2e70a0a1bb2db2282d2c1972c1d38eafc3b"}, - {file = "rpds_py-0.16.2-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:47713dc4fce213f5c74ca8a1f6a59b622fc1b90868deb8e8e4d993e421b4b39d"}, - {file = "rpds_py-0.16.2-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:f811771019f063bbd0aa7bb72c8a934bc13ebacb4672d712fc1639cfd314cccc"}, - {file = "rpds_py-0.16.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f19afcfc0dd0dca35694df441e9b0f95bc231b512f51bded3c3d8ca32153ec19"}, - {file = 
"rpds_py-0.16.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:a4b682c5775d6a3d21e314c10124599976809455ee67020e8e72df1769b87bc3"}, - {file = "rpds_py-0.16.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c647ca87fc0ebe808a41de912e9a1bfef9acb85257e5d63691364ac16b81c1f0"}, - {file = "rpds_py-0.16.2-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:302bd4983bbd47063e452c38be66153760112f6d3635c7eeefc094299fa400a9"}, - {file = "rpds_py-0.16.2-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bf721ede3eb7b829e4a9b8142bd55db0bdc82902720548a703f7e601ee13bdc3"}, - {file = "rpds_py-0.16.2-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:358dafc89ce3894c7f486c615ba914609f38277ef67f566abc4c854d23b997fa"}, - {file = "rpds_py-0.16.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cad0f59ee3dc35526039f4bc23642d52d5f6616b5f687d846bfc6d0d6d486db0"}, - {file = "rpds_py-0.16.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cffa76b385dfe1e38527662a302b19ffb0e7f5cf7dd5e89186d2c94a22dd9d0c"}, - {file = "rpds_py-0.16.2-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:83640a5d7cd3bff694747d50436b8b541b5b9b9782b0c8c1688931d6ee1a1f2d"}, - {file = "rpds_py-0.16.2-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:ed99b4f7179d2111702020fd7d156e88acd533f5a7d3971353e568b6051d5c97"}, - {file = "rpds_py-0.16.2-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:4022b9dc620e14f30201a8a73898a873c8e910cb642bcd2f3411123bc527f6ac"}, - {file = "rpds_py-0.16.2.tar.gz", hash = "sha256:781ef8bfc091b19960fc0142a23aedadafa826bc32b433fdfe6fd7f964d7ef44"}, + {file = "rpds_py-0.18.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:5b4e7d8d6c9b2e8ee2d55c90b59c707ca59bc30058269b3db7b1f8df5763557e"}, + {file = "rpds_py-0.18.0-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:c463ed05f9dfb9baebef68048aed8dcdc94411e4bf3d33a39ba97e271624f8f7"}, + {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01e36a39af54a30f28b73096dd39b6802eddd04c90dbe161c1b8dbe22353189f"}, + {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d62dec4976954a23d7f91f2f4530852b0c7608116c257833922a896101336c51"}, + {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dd18772815d5f008fa03d2b9a681ae38d5ae9f0e599f7dda233c439fcaa00d40"}, + {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:923d39efa3cfb7279a0327e337a7958bff00cc447fd07a25cddb0a1cc9a6d2da"}, + {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39514da80f971362f9267c600b6d459bfbbc549cffc2cef8e47474fddc9b45b1"}, + {file = "rpds_py-0.18.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a34d557a42aa28bd5c48a023c570219ba2593bcbbb8dc1b98d8cf5d529ab1434"}, + {file = "rpds_py-0.18.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:93df1de2f7f7239dc9cc5a4a12408ee1598725036bd2dedadc14d94525192fc3"}, + {file = "rpds_py-0.18.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:34b18ba135c687f4dac449aa5157d36e2cbb7c03cbea4ddbd88604e076aa836e"}, + {file = "rpds_py-0.18.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c0b5dcf9193625afd8ecc92312d6ed78781c46ecbf39af9ad4681fc9f464af88"}, + {file = "rpds_py-0.18.0-cp310-none-win32.whl", hash = "sha256:c4325ff0442a12113a6379af66978c3fe562f846763287ef66bdc1d57925d337"}, + {file = "rpds_py-0.18.0-cp310-none-win_amd64.whl", hash = "sha256:7223a2a5fe0d217e60a60cdae28d6949140dde9c3bcc714063c5b463065e3d66"}, + {file = "rpds_py-0.18.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:3a96e0c6a41dcdba3a0a581bbf6c44bb863f27c541547fb4b9711fd8cf0ffad4"}, + {file = 
"rpds_py-0.18.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30f43887bbae0d49113cbaab729a112251a940e9b274536613097ab8b4899cf6"}, + {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fcb25daa9219b4cf3a0ab24b0eb9a5cc8949ed4dc72acb8fa16b7e1681aa3c58"}, + {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d68c93e381010662ab873fea609bf6c0f428b6d0bb00f2c6939782e0818d37bf"}, + {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b34b7aa8b261c1dbf7720b5d6f01f38243e9b9daf7e6b8bc1fd4657000062f2c"}, + {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2e6d75ab12b0bbab7215e5d40f1e5b738aa539598db27ef83b2ec46747df90e1"}, + {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b8612cd233543a3781bc659c731b9d607de65890085098986dfd573fc2befe5"}, + {file = "rpds_py-0.18.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:aec493917dd45e3c69d00a8874e7cbed844efd935595ef78a0f25f14312e33c6"}, + {file = "rpds_py-0.18.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:661d25cbffaf8cc42e971dd570d87cb29a665f49f4abe1f9e76be9a5182c4688"}, + {file = "rpds_py-0.18.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1df3659d26f539ac74fb3b0c481cdf9d725386e3552c6fa2974f4d33d78e544b"}, + {file = "rpds_py-0.18.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a1ce3ba137ed54f83e56fb983a5859a27d43a40188ba798993812fed73c70836"}, + {file = "rpds_py-0.18.0-cp311-none-win32.whl", hash = "sha256:69e64831e22a6b377772e7fb337533c365085b31619005802a79242fee620bc1"}, + {file = "rpds_py-0.18.0-cp311-none-win_amd64.whl", hash = "sha256:998e33ad22dc7ec7e030b3df701c43630b5bc0d8fbc2267653577e3fec279afa"}, + {file = "rpds_py-0.18.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = 
"sha256:7f2facbd386dd60cbbf1a794181e6aa0bd429bd78bfdf775436020172e2a23f0"}, + {file = "rpds_py-0.18.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1d9a5be316c15ffb2b3c405c4ff14448c36b4435be062a7f578ccd8b01f0c4d8"}, + {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd5bf1af8efe569654bbef5a3e0a56eca45f87cfcffab31dd8dde70da5982475"}, + {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5417558f6887e9b6b65b4527232553c139b57ec42c64570569b155262ac0754f"}, + {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:56a737287efecafc16f6d067c2ea0117abadcd078d58721f967952db329a3e5c"}, + {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8f03bccbd8586e9dd37219bce4d4e0d3ab492e6b3b533e973fa08a112cb2ffc9"}, + {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4457a94da0d5c53dc4b3e4de1158bdab077db23c53232f37a3cb7afdb053a4e3"}, + {file = "rpds_py-0.18.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0ab39c1ba9023914297dd88ec3b3b3c3f33671baeb6acf82ad7ce883f6e8e157"}, + {file = "rpds_py-0.18.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9d54553c1136b50fd12cc17e5b11ad07374c316df307e4cfd6441bea5fb68496"}, + {file = "rpds_py-0.18.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0af039631b6de0397ab2ba16eaf2872e9f8fca391b44d3d8cac317860a700a3f"}, + {file = "rpds_py-0.18.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:84ffab12db93b5f6bad84c712c92060a2d321b35c3c9960b43d08d0f639d60d7"}, + {file = "rpds_py-0.18.0-cp312-none-win32.whl", hash = "sha256:685537e07897f173abcf67258bee3c05c374fa6fff89d4c7e42fb391b0605e98"}, + {file = "rpds_py-0.18.0-cp312-none-win_amd64.whl", hash = "sha256:e003b002ec72c8d5a3e3da2989c7d6065b47d9eaa70cd8808b5384fbb970f4ec"}, + {file = 
"rpds_py-0.18.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:08f9ad53c3f31dfb4baa00da22f1e862900f45908383c062c27628754af2e88e"}, + {file = "rpds_py-0.18.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c0013fe6b46aa496a6749c77e00a3eb07952832ad6166bd481c74bda0dcb6d58"}, + {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e32a92116d4f2a80b629778280103d2a510a5b3f6314ceccd6e38006b5e92dcb"}, + {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e541ec6f2ec456934fd279a3120f856cd0aedd209fc3852eca563f81738f6861"}, + {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bed88b9a458e354014d662d47e7a5baafd7ff81c780fd91584a10d6ec842cb73"}, + {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2644e47de560eb7bd55c20fc59f6daa04682655c58d08185a9b95c1970fa1e07"}, + {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e8916ae4c720529e18afa0b879473049e95949bf97042e938530e072fde061d"}, + {file = "rpds_py-0.18.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:465a3eb5659338cf2a9243e50ad9b2296fa15061736d6e26240e713522b6235c"}, + {file = "rpds_py-0.18.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ea7d4a99f3b38c37eac212dbd6ec42b7a5ec51e2c74b5d3223e43c811609e65f"}, + {file = "rpds_py-0.18.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:67071a6171e92b6da534b8ae326505f7c18022c6f19072a81dcf40db2638767c"}, + {file = "rpds_py-0.18.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:41ef53e7c58aa4ef281da975f62c258950f54b76ec8e45941e93a3d1d8580594"}, + {file = "rpds_py-0.18.0-cp38-none-win32.whl", hash = "sha256:fdea4952db2793c4ad0bdccd27c1d8fdd1423a92f04598bc39425bcc2b8ee46e"}, + {file = "rpds_py-0.18.0-cp38-none-win_amd64.whl", hash = "sha256:7cd863afe7336c62ec78d7d1349a2f34c007a3cc6c2369d667c65aeec412a5b1"}, + {file = 
"rpds_py-0.18.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5307def11a35f5ae4581a0b658b0af8178c65c530e94893345bebf41cc139d33"}, + {file = "rpds_py-0.18.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77f195baa60a54ef9d2de16fbbfd3ff8b04edc0c0140a761b56c267ac11aa467"}, + {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39f5441553f1c2aed4de4377178ad8ff8f9d733723d6c66d983d75341de265ab"}, + {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9a00312dea9310d4cb7dbd7787e722d2e86a95c2db92fbd7d0155f97127bcb40"}, + {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8f2fc11e8fe034ee3c34d316d0ad8808f45bc3b9ce5857ff29d513f3ff2923a1"}, + {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:586f8204935b9ec884500498ccc91aa869fc652c40c093bd9e1471fbcc25c022"}, + {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddc2f4dfd396c7bfa18e6ce371cba60e4cf9d2e5cdb71376aa2da264605b60b9"}, + {file = "rpds_py-0.18.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5ddcba87675b6d509139d1b521e0c8250e967e63b5909a7e8f8944d0f90ff36f"}, + {file = "rpds_py-0.18.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:7bd339195d84439cbe5771546fe8a4e8a7a045417d8f9de9a368c434e42a721e"}, + {file = "rpds_py-0.18.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:d7c36232a90d4755b720fbd76739d8891732b18cf240a9c645d75f00639a9024"}, + {file = "rpds_py-0.18.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:6b0817e34942b2ca527b0e9298373e7cc75f429e8da2055607f4931fded23e20"}, + {file = "rpds_py-0.18.0-cp39-none-win32.whl", hash = "sha256:99f70b740dc04d09e6b2699b675874367885217a2e9f782bdf5395632ac663b7"}, + {file = "rpds_py-0.18.0-cp39-none-win_amd64.whl", hash = "sha256:6ef687afab047554a2d366e112dd187b62d261d49eb79b77e386f94644363294"}, + {file = 
"rpds_py-0.18.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ad36cfb355e24f1bd37cac88c112cd7730873f20fb0bdaf8ba59eedf8216079f"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:36b3ee798c58ace201289024b52788161e1ea133e4ac93fba7d49da5fec0ef9e"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8a2f084546cc59ea99fda8e070be2fd140c3092dc11524a71aa8f0f3d5a55ca"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e4461d0f003a0aa9be2bdd1b798a041f177189c1a0f7619fe8c95ad08d9a45d7"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8db715ebe3bb7d86d77ac1826f7d67ec11a70dbd2376b7cc214199360517b641"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:793968759cd0d96cac1e367afd70c235867831983f876a53389ad869b043c948"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66e6a3af5a75363d2c9a48b07cb27c4ea542938b1a2e93b15a503cdfa8490795"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6ef0befbb5d79cf32d0266f5cff01545602344eda89480e1dd88aca964260b18"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:1d4acf42190d449d5e89654d5c1ed3a4f17925eec71f05e2a41414689cda02d1"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:a5f446dd5055667aabaee78487f2b5ab72e244f9bc0b2ffebfeec79051679984"}, + {file = "rpds_py-0.18.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:9dbbeb27f4e70bfd9eec1be5477517365afe05a9b2c441a0b21929ee61048124"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:22806714311a69fd0af9b35b7be97c18a0fc2826e6827dbb3a8c94eac6cf7eeb"}, + {file = 
"rpds_py-0.18.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:b34ae4636dfc4e76a438ab826a0d1eed2589ca7d9a1b2d5bb546978ac6485461"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c8370641f1a7f0e0669ddccca22f1da893cef7628396431eb445d46d893e5cd"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c8362467a0fdeccd47935f22c256bec5e6abe543bf0d66e3d3d57a8fb5731863"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11a8c85ef4a07a7638180bf04fe189d12757c696eb41f310d2426895356dcf05"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b316144e85316da2723f9d8dc75bada12fa58489a527091fa1d5a612643d1a0e"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf1ea2e34868f6fbf070e1af291c8180480310173de0b0c43fc38a02929fc0e3"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e546e768d08ad55b20b11dbb78a745151acbd938f8f00d0cfbabe8b0199b9880"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4901165d170a5fde6f589acb90a6b33629ad1ec976d4529e769c6f3d885e3e80"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-musllinux_1_2_i686.whl", hash = "sha256:618a3d6cae6ef8ec88bb76dd80b83cfe415ad4f1d942ca2a903bf6b6ff97a2da"}, + {file = "rpds_py-0.18.0-pp38-pypy38_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:ed4eb745efbff0a8e9587d22a84be94a5eb7d2d99c02dacf7bd0911713ed14dd"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6c81e5f372cd0dc5dc4809553d34f832f60a46034a5f187756d9b90586c2c307"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:43fbac5f22e25bee1d482c97474f930a353542855f05c1161fd804c9dc74a09d"}, + {file = 
"rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d7faa6f14017c0b1e69f5e2c357b998731ea75a442ab3841c0dbbbfe902d2c4"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:08231ac30a842bd04daabc4d71fddd7e6d26189406d5a69535638e4dcb88fe76"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:044a3e61a7c2dafacae99d1e722cc2d4c05280790ec5a05031b3876809d89a5c"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f26b5bd1079acdb0c7a5645e350fe54d16b17bfc5e71f371c449383d3342e17"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:482103aed1dfe2f3b71a58eff35ba105289b8d862551ea576bd15479aba01f66"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1374f4129f9bcca53a1bba0bb86bf78325a0374577cf7e9e4cd046b1e6f20e24"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:635dc434ff724b178cb192c70016cc0ad25a275228f749ee0daf0eddbc8183b1"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:bc362ee4e314870a70f4ae88772d72d877246537d9f8cb8f7eacf10884862432"}, + {file = "rpds_py-0.18.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:4832d7d380477521a8c1644bbab6588dfedea5e30a7d967b5fb75977c45fd77f"}, + {file = "rpds_py-0.18.0.tar.gz", hash = "sha256:42821446ee7a76f5d9f71f9e33a4fb2ffd724bb3e7f93386150b61a43115788d"}, ] [[package]] name = "safetensors" -version = "0.4.1" +version = "0.4.2" description = "" optional = false python-versions = ">=3.7" files = [ - {file = "safetensors-0.4.1-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:cba01c6b76e01ec453933b3b3c0157c59b52881c83eaa0f7666244e71aa75fd1"}, - {file = "safetensors-0.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:7a8f6f679d97ea0135c7935c202feefbd042c149aa70ee759855e890c01c7814"}, - {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bbc2ce1f5ae5143a7fb72b71fa71db6a42b4f6cf912aa3acdc6b914084778e68"}, - {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2d87d993eaefe6611a9c241a8bd364a5f1ffed5771c74840363a6c4ed8d868f6"}, - {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:097e9af2efa8778cd2f0cba451784253e62fa7cc9fc73c0744d27212f7294e25"}, - {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d10a9f7bae608ccfdc009351f01dc3d8535ff57f9488a58a4c38e45bf954fe93"}, - {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:270b99885ec14abfd56c1d7f28ada81740a9220b4bae960c3de1c6fe84af9e4d"}, - {file = "safetensors-0.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:285b52a481e7ba93e29ad4ec5841ef2c4479ef0a6c633c4e2629e0508453577b"}, - {file = "safetensors-0.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c3c9f0ca510e0de95abd6424789dcbc879942a3a4e29b0dfa99d9427bf1da75c"}, - {file = "safetensors-0.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:88b4653059c903015284a9722f9a46838c654257173b279c8f6f46dbe80b612d"}, - {file = "safetensors-0.4.1-cp310-none-win32.whl", hash = "sha256:2fe6926110e3d425c4b684a4379b7796fdc26ad7d16922ea1696c8e6ea7e920f"}, - {file = "safetensors-0.4.1-cp310-none-win_amd64.whl", hash = "sha256:a79e16222106b2f5edbca1b8185661477d8971b659a3c814cc6f15181a9b34c8"}, - {file = "safetensors-0.4.1-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:d93321eea0dd7e81b283e47a1d20dee6069165cc158286316d0d06d340de8fe8"}, - {file = "safetensors-0.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8ff8e41c8037db17de0ea2a23bc684f43eaf623be7d34906fe1ac10985b8365e"}, - {file = 
"safetensors-0.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39d36f1d88468a87c437a1bc27c502e71b6ca44c385a9117a9f9ba03a75cc9c6"}, - {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7ef010e9afcb4057fb6be3d0a0cfa07aac04fe97ef73fe4a23138d8522ba7c17"}, - {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b287304f2b2220d51ccb51fd857761e78bcffbeabe7b0238f8dc36f2edfd9542"}, - {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e09000b2599e1836314430f81a3884c66a5cbabdff5d9f175b5d560d4de38d78"}, - {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9c80ce0001efa16066358d2dd77993adc25f5a6c61850e4ad096a2232930bce"}, - {file = "safetensors-0.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:413e1f6ac248f7d1b755199a06635e70c3515493d3b41ba46063dec33aa2ebb7"}, - {file = "safetensors-0.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3ac139377cfe71ba04573f1cda66e663b7c3e95be850e9e6c2dd4b5984bd513"}, - {file = "safetensors-0.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:04157d008385bea66d12fe90844a80d4a76dc25ec5230b5bd9a630496d1b7c03"}, - {file = "safetensors-0.4.1-cp311-none-win32.whl", hash = "sha256:5f25297148ec665f0deb8bd67e9564634d8d6841041ab5393ccfe203379ea88b"}, - {file = "safetensors-0.4.1-cp311-none-win_amd64.whl", hash = "sha256:b2f8877990a72ff595507b80f4b69036a9a1986a641f8681adf3425d97d3d2a5"}, - {file = "safetensors-0.4.1-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:eb2c1da1cc39509d1a55620a5f4d14f8911c47a89c926a96e6f4876e864375a3"}, - {file = "safetensors-0.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:303d2c0415cf15a28f8d7f17379ea3c34c2b466119118a34edd9965983a1a8a6"}, - {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:bb4cb3e37a9b961ddd68e873b29fe9ab4a081e3703412e34aedd2b7a8e9cafd9"}, - {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ae5497adc68669db2fed7cb2dad81e6a6106e79c9a132da3efdb6af1db1014fa"}, - {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b30abd0cddfe959d1daedf92edcd1b445521ebf7ddefc20860ed01486b33c90"}, - {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d784a98c492c751f228a4a894c3b8a092ff08b24e73b5568938c28b8c0e8f8df"}, - {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e57a5ab08b0ec7a7caf30d2ac79bb30c89168431aca4f8854464bb9461686925"}, - {file = "safetensors-0.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:edcf3121890b5f0616aa5a54683b1a5d2332037b970e507d6bb7841a3a596556"}, - {file = "safetensors-0.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fdb58dee173ef33634c3016c459d671ca12d11e6acf9db008261cbe58107e579"}, - {file = "safetensors-0.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:780dc21eb3fd32ddd0e8c904bdb0290f2454f4ac21ae71e94f9ce72db1900a5a"}, - {file = "safetensors-0.4.1-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:48901bd540f8a3c1791314bc5c8a170927bf7f6acddb75bf0a263d081a3637d4"}, - {file = "safetensors-0.4.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:3b0b7b2d5976fbed8a05e2bbdce5816a59e6902e9e7c7e07dc723637ed539787"}, - {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f69903ff49cb30b9227fb5d029bea276ea20d04b06803877a420c5b1b74c689"}, - {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0ddd050e01f3e843aa8c1c27bf68675b8a08e385d0045487af4d70418c3cb356"}, - {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:9a82bc2bd7a9a0e08239bdd6d7774d64121f136add93dfa344a2f1a6d7ef35fa"}, - {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6ace9e66a40f98a216ad661245782483cf79cf56eb2b112650bb904b0baa9db5"}, - {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82cbb8f4d022f2e94498cbefca900698b8ded3d4f85212f47da614001ff06652"}, - {file = "safetensors-0.4.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:791edc10a3c359a2f5f52d5cddab0df8a45107d91027d86c3d44e57162e5d934"}, - {file = "safetensors-0.4.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:83c2cfbe8c6304f0891e7bb378d56f66d2148972eeb5f747cd8a2246886f0d8c"}, - {file = "safetensors-0.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:04dd14f53f5500eb4c4149674216ba1000670efbcf4b1b5c2643eb244e7882ea"}, - {file = "safetensors-0.4.1-cp37-none-win32.whl", hash = "sha256:d5b3defa74f3723a388bfde2f5d488742bc4879682bd93267c09a3bcdf8f869b"}, - {file = "safetensors-0.4.1-cp37-none-win_amd64.whl", hash = "sha256:25a043cbb59d4f75e9dd87fdf5c009dd8830105a2c57ace49b72167dd9808111"}, - {file = "safetensors-0.4.1-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:3f6a520af7f2717c5ecba112041f2c8af1ca6480b97bf957aba81ed9642e654c"}, - {file = "safetensors-0.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c3807ac3b16288dffebb3474b555b56fe466baa677dfc16290dcd02dca1ab228"}, - {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b58ba13a9e82b4bc3fc221914f6ef237fe6c2adb13cede3ace64d1aacf49610"}, - {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dac4bb42f8679aadc59bd91a4c5a1784a758ad49d0912995945cd674089f628e"}, - {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:911b48dc09e321a194def3a7431662ff4f03646832f3a8915bbf0f449b8a5fcb"}, - {file = 
"safetensors-0.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:82571d20288c975c1b30b08deb9b1c3550f36b31191e1e81fae87669a92217d0"}, - {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da52ee0dc8ba03348ffceab767bd8230842fdf78f8a996e2a16445747143a778"}, - {file = "safetensors-0.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2536b11ce665834201072e9397404170f93f3be10cca9995b909f023a04501ee"}, - {file = "safetensors-0.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:998fbac99ca956c3a09fe07cc0b35fac26a521fa8865a690686d889f0ff4e4a6"}, - {file = "safetensors-0.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:845be0aafabf2a60c2d482d4e93023fecffe5e5443d801d7a7741bae9de41233"}, - {file = "safetensors-0.4.1-cp38-none-win32.whl", hash = "sha256:ce7a28bc8af685a69d7e869d09d3e180a275e3281e29cf5f1c7319e231932cc7"}, - {file = "safetensors-0.4.1-cp38-none-win_amd64.whl", hash = "sha256:e056fb9e22d118cc546107f97dc28b449d88274207dd28872bd668c86216e4f6"}, - {file = "safetensors-0.4.1-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:bdc0d039e44a727824639824090bd8869535f729878fa248addd3dc01db30eae"}, - {file = "safetensors-0.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3c1b1d510c7aba71504ece87bf393ea82638df56303e371e5e2cf09d18977dd7"}, - {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bd0afd95c1e497f520e680ea01e0397c0868a3a3030e128438cf6e9e3fcd671"}, - {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f603bdd8deac6726d39f41688ed353c532dd53935234405d79e9eb53f152fbfb"}, - {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8a85e3e47e0d4eebfaf9a58b40aa94f977a56050cb5598ad5396a9ee7c087c6"}, - {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:e0ccb5aa0f3be2727117e5631200fbb3a5b3a2b3757545a92647d6dd8be6658f"}, - {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d784938534e255473155e4d9f276ee69eb85455b6af1292172c731409bf9adee"}, - {file = "safetensors-0.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a257de175c254d39ccd6a21341cd62eb7373b05c1e618a78096a56a857e0c316"}, - {file = "safetensors-0.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6fd80f7794554091836d4d613d33a7d006e2b8d6ba014d06f97cebdfda744f64"}, - {file = "safetensors-0.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:35803201d980efcf964b75a0a2aee97fe5e9ecc5f3ad676b38fafdfe98e0620d"}, - {file = "safetensors-0.4.1-cp39-none-win32.whl", hash = "sha256:7ff8a36e0396776d3ed9a106fc9a9d7c55d4439ca9a056a24bf66d343041d3e6"}, - {file = "safetensors-0.4.1-cp39-none-win_amd64.whl", hash = "sha256:bfa2e20342b81921b98edba52f8deb68843fa9c95250739a56b52ceda5ea5c61"}, - {file = "safetensors-0.4.1-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:ae2d5a31cfb8a973a318f7c4d2cffe0bd1fe753cdf7bb41a1939d45a0a06f964"}, - {file = "safetensors-0.4.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1a45dbf03e8334d3a5dc93687d98b6dc422f5d04c7d519dac09b84a3c87dd7c6"}, - {file = "safetensors-0.4.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2297b359d91126c0f9d4fd17bae3cfa2fe3a048a6971b8db07db746ad92f850c"}, - {file = "safetensors-0.4.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bda3d98e2bcece388232cfc551ebf063b55bdb98f65ab54df397da30efc7dcc5"}, - {file = "safetensors-0.4.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f8934bdfd202ebd0697040a3dff40dd77bc4c5bbf3527ede0532f5e7fb4d970f"}, - {file = "safetensors-0.4.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:42c3710cec7e5c764c7999697516370bee39067de0aa089b7e2cfb97ac8c6b20"}, - {file = 
"safetensors-0.4.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:53134226053e56bd56e73f7db42596e7908ed79f3c9a1016e4c1dade593ac8e5"}, - {file = "safetensors-0.4.1-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:257d59e40a1b367cb544122e7451243d65b33c3f34d822a347f4eea6fdf97fdf"}, - {file = "safetensors-0.4.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d54c2f1826e790d1eb2d2512bfd0ee443f0206b423d6f27095057c7f18a0687"}, - {file = "safetensors-0.4.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:645b3f1138fce6e818e79d4128afa28f0657430764cc045419c1d069ff93f732"}, - {file = "safetensors-0.4.1-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e9a7ffb1e551c6df51d267f5a751f042b183df22690f6feceac8d27364fd51d7"}, - {file = "safetensors-0.4.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:44e230fbbe120de564b64f63ef3a8e6ff02840fa02849d9c443d56252a1646d4"}, - {file = "safetensors-0.4.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:9d16b3b2fcc6fca012c74bd01b5619c655194d3e3c13e4d4d0e446eefa39a463"}, - {file = "safetensors-0.4.1-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:5d95ea4d8b32233910734a904123bdd3979c137c461b905a5ed32511defc075f"}, - {file = "safetensors-0.4.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:dab431699b5d45e0ca043bc580651ce9583dda594e62e245b7497adb32e99809"}, - {file = "safetensors-0.4.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16d8bbb7344e39cb9d4762e85c21df94ebeb03edac923dd94bb9ed8c10eac070"}, - {file = "safetensors-0.4.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1faf5111c66a6ba91f85dff2e36edaaf36e6966172703159daeef330de4ddc7b"}, - {file = "safetensors-0.4.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:660ca1d8bff6c7bc7c6b30b9b32df74ef3ab668f5df42cefd7588f0d40feadcb"}, - {file = 
"safetensors-0.4.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ae2f67f04ed0bb2e56fd380a8bd3eef03f609df53f88b6f5c7e89c08e52aae00"}, - {file = "safetensors-0.4.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:c8ed5d2c04cdc1afc6b3c28d59580448ac07732c50d94c15e14670f9c473a2ce"}, - {file = "safetensors-0.4.1-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:2b6a2814278b6660261aa9a9aae524616de9f1ec364e3716d219b6ed8f91801f"}, - {file = "safetensors-0.4.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:3cfd1ca35eacc635f0eaa894e5c5ed83ffebd0f95cac298fd430014fa7323631"}, - {file = "safetensors-0.4.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4177b456c6b0c722d82429127b5beebdaf07149d265748e97e0a34ff0b3694c8"}, - {file = "safetensors-0.4.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:313e8472197bde54e3ec54a62df184c414582979da8f3916981b6a7954910a1b"}, - {file = "safetensors-0.4.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fdb4adb76e21bad318210310590de61c9f4adcef77ee49b4a234f9dc48867869"}, - {file = "safetensors-0.4.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:1d568628e9c43ca15eb96c217da73737c9ccb07520fafd8a1eba3f2750614105"}, - {file = "safetensors-0.4.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:573b6023a55a2f28085fc0a84e196c779b6cbef4d9e73acea14c8094fee7686f"}, - {file = "safetensors-0.4.1.tar.gz", hash = "sha256:2304658e6ada81a5223225b4efe84748e760c46079bffedf7e321763cafb36c9"}, + {file = "safetensors-0.4.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:69d8bb8384dc2cb5b72c36c4d6980771b293d1a1377b378763f5e37b6bb8d133"}, + {file = "safetensors-0.4.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3d420e19fcef96d0067f4de4699682b4bbd85fc8fea0bd45fcd961fdf3e8c82c"}, + {file = "safetensors-0.4.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:9ca54742122fa3c4821754adb67318e1cd25c3a22bbf0c5520d5176e77a099ac"}, + {file = "safetensors-0.4.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b47aa643afdfd66cf7ce4c184092ae734e15d10aba2c2948f24270211801c3c"}, + {file = "safetensors-0.4.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d88a16bbc330f27e7f2d4caaf6fb061ad0b8a756ecc4033260b0378e128ce8a2"}, + {file = "safetensors-0.4.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9223b8ac21085db614a510eb3445e7083cae915a9202357555fa939695d4f57"}, + {file = "safetensors-0.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce6cb86133dc8930a7ab5e7438545a7f205f7a1cdd5aaf108c1d0da6bdcfbc2b"}, + {file = "safetensors-0.4.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b8a628e0ae2bbc334b62952c384aa5f41621d01850f8d67b04a96b9c39dd7326"}, + {file = "safetensors-0.4.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:88d6beb7f811a081e0e5f1d9669fdac816c45340c04b1eaf7ebfda0ce93ea403"}, + {file = "safetensors-0.4.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b57fc5b1b54cb12d8690a58a4cf4b7144730d4bde9d98aa0e1dab6295a1cd579"}, + {file = "safetensors-0.4.2-cp310-none-win32.whl", hash = "sha256:9d87a1c98803c16cf113b9ba03f07b2dce5e8eabfd1811a7f7323fcaa2a1bf47"}, + {file = "safetensors-0.4.2-cp310-none-win_amd64.whl", hash = "sha256:18930ec1d1ecb526d3d9835abc2489b8f1530877518f0c541e77ef0b7abcbd99"}, + {file = "safetensors-0.4.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:c5dd2ed788730ed56b415d1a11c62026b8cc8c573f55a2092afb3ab383e94fff"}, + {file = "safetensors-0.4.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cc41791b33efb9c83a59b731619f3d15f543dfe71f3a793cb8fbf9bd5d0d5d71"}, + {file = "safetensors-0.4.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c888bf71d5ca12a720f1ed87d407c4918afa022fb247a6546d8fac15b1f112b"}, + {file = 
"safetensors-0.4.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e6b2feb4b47226a16a792e6fac3f49442714884a3d4c1008569d5068a3941be9"}, + {file = "safetensors-0.4.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f41cc0ee4b838ae8f4d8364a1b162067693d11a3893f0863be8c228d40e4d0ee"}, + {file = "safetensors-0.4.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:51b7228e46c0a483c40ba4b9470dea00fb1ff8685026bb4766799000f6328ac2"}, + {file = "safetensors-0.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02697f8f2be8ca3c37a4958702dbdb1864447ef765e18b5328a1617022dcf164"}, + {file = "safetensors-0.4.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:27fd8f65cf7c80e4280cae1ee6bcd85c483882f6580821abe71ee1a0d3dcfca7"}, + {file = "safetensors-0.4.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c487b5f113b0924c9534a07dc034830fb4ef05ce9bb6d78cfe016a7dedfe281f"}, + {file = "safetensors-0.4.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:da7f6483f3fe67ff39b3a55552552c67930ea10a36e9f2539d36fc205273d767"}, + {file = "safetensors-0.4.2-cp311-none-win32.whl", hash = "sha256:52a7012f6cb9cb4a132760b6308daede18a9f5f8952ce08adc7c67a7d865c2d8"}, + {file = "safetensors-0.4.2-cp311-none-win_amd64.whl", hash = "sha256:4d1361a097ac430b310ce9eed8ed4746edee33ddafdfbb965debc8966fc34dc2"}, + {file = "safetensors-0.4.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:77af8aa0edcc2863760fd6febbfdb82e88fd75d0e60c1ce4ba57208ba5e4a89b"}, + {file = "safetensors-0.4.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:846666c1c5a8c8888d2dfda8d3921cb9cb8e2c5f78365be756c11021e75a0a2a"}, + {file = "safetensors-0.4.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f4bfc7ea19b446bfad41510d4b4c76101698c00caaa8a332c8edd8090a412ef"}, + {file = "safetensors-0.4.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:233436fd30f27ffeb3c3780d0b84f496518868445c7a8db003639a649cc98453"}, + {file = "safetensors-0.4.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7a09237a795d11cd11f9dae505d170a29b5616151db1e10c14f892b11caadc7d"}, + {file = "safetensors-0.4.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:de01c9a3a3b7b69627d624ff69d9f11d28ce9908eea2fb6245adafa4b1d43df6"}, + {file = "safetensors-0.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c1f25c5069ee42a5bcffdc66c300a407941edd73f3239e9fdefd26216407391"}, + {file = "safetensors-0.4.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7a73b3649456d09ca8506140d44484b63154a7378434cc1e8719f8056550b224"}, + {file = "safetensors-0.4.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e1625a8d07d046e968bd5c4961810aba1225984e4fb9243626f9d04a06ed3fee"}, + {file = "safetensors-0.4.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f74c86b25615cb24ad4cff765a2eefc09d71bf0fed97588cf585aad9c38fbb4"}, + {file = "safetensors-0.4.2-cp312-none-win32.whl", hash = "sha256:8523b9c5777d771bcde5c2389c03f1cdf7ebe8797432a1bd5e345efe25c55987"}, + {file = "safetensors-0.4.2-cp312-none-win_amd64.whl", hash = "sha256:dcff0243e1737a21f83d664c63fed89d1f532c23fc6830d0427279fabd789ccb"}, + {file = "safetensors-0.4.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:96ad3d7d472612e26cbe413922b4fb13933310f0511d346ea5cc9a1e856e52eb"}, + {file = "safetensors-0.4.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:88250922401b5ae4e37de929178caf46be47ed16c817b2237b81679bec07c120"}, + {file = "safetensors-0.4.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d40443554142fc0ab30652d5cc8554c4b7a613513bde00373e18afd5de8cbe4b"}, + {file = "safetensors-0.4.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:27f53f70106224d32d874aacecbeb4a6e4c5b16a1d2006d0e876d97229086d71"}, + {file = 
"safetensors-0.4.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cc068afe23734dfb26ce19db0a7877499ddf73b1d55ceb762417e8da4a1b05fb"}, + {file = "safetensors-0.4.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9be1918eb8d43a11a6f8806759fccfa0eeb0542b12924caba66af8a7800ad01a"}, + {file = "safetensors-0.4.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41911087d20a7bbd78cb4ad4f98aab0c431533107584df6635d8b54b99945573"}, + {file = "safetensors-0.4.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:50771c662aab909f31e94d048e76861fd027d66076ea773eef2e66c717766e24"}, + {file = "safetensors-0.4.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:13f2e57be007b7ea9329133d2399e6bdfcf1910f655440a4da17df3a45afcd30"}, + {file = "safetensors-0.4.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c772147e6395bc829842e0a98e1b30c67fe25d816299c28196488511d5a5e951"}, + {file = "safetensors-0.4.2-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:36239a0060b537a3e8c473df78cffee14c3ec4f51d5f1a853af99371a2fb2a35"}, + {file = "safetensors-0.4.2-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:d0cbb7664fad2c307f95195f951b7059e95dc23e0e1822e5978c8b500098543c"}, + {file = "safetensors-0.4.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b3e55adb6bd9dc1c2a341e72f48f075953fa35d173dd8e29a95b3b02d0d1462"}, + {file = "safetensors-0.4.2-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42f743b3cca863fba53ca57a193f510e5ec359b97f38c282437716b6768e4a25"}, + {file = "safetensors-0.4.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04e6af4a6dbeb06c4e6e7d46cf9c716cbc4cc5ef62584fd8a7c0fe558562df45"}, + {file = "safetensors-0.4.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a492ba21b5c8f14ee5ec9b20f42ba969e53ca1f909a4d04aad736b66a341dcc2"}, + {file = 
"safetensors-0.4.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b25b8233a1a85dc67e39838951cfb01595d792f3b7b644add63edb652992e030"}, + {file = "safetensors-0.4.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fd27e063fbdafe776f7b1714da59110e88f270e86db00788a8fd65f4eacfeba7"}, + {file = "safetensors-0.4.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1b6fa399f251bbeb52029bf5a0ac2878d7705dd3612a2f8895b48e9c11f0367d"}, + {file = "safetensors-0.4.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:de642d46b459e4afd5c2020b26c0d6d869a171ea00411897d5776c127cac74f0"}, + {file = "safetensors-0.4.2-cp37-none-win32.whl", hash = "sha256:77b72d17754c93bb68f3598182f14d78776e0b9b31682ca5bb2c7c5bd9a75267"}, + {file = "safetensors-0.4.2-cp37-none-win_amd64.whl", hash = "sha256:d36ee3244d461cd655aeef493792c3bccf4875282f8407fd9af99e9a41cf2530"}, + {file = "safetensors-0.4.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:16b6b3884f7876c6b3b23a742428223a7170a5a9dac819d8c12a1569422c4b5a"}, + {file = "safetensors-0.4.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ee25d311493fbbe0be9d395faee46e9d79e8948f461e388ff39e59875ed9a350"}, + {file = "safetensors-0.4.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eed8097968585cd752a1171f86fce9aa1d89a29033e5cd8bec5a502e29f6b7af"}, + {file = "safetensors-0.4.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:880e6865cf72cb67f9ab8d04a3c4b49dd95ae92fb1583929ce65aed94e1f685f"}, + {file = "safetensors-0.4.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91290f83daf80ce6d1a7f629b244443c200060a80f908b29d879021409e5ea94"}, + {file = "safetensors-0.4.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3517d568486ab3508a7acc360b82d7a4a3e26b86efdf210a9ecd9d233c40708a"}, + {file = "safetensors-0.4.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:e1f43a77eb38540f782999e5dc5645164fe9027d3f0194f6c9a5126168017efa"}, + {file = "safetensors-0.4.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b684d9818aa5d63fddc65f7d0151968037d255d91adf74eba82125b41c680aaa"}, + {file = "safetensors-0.4.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ab1f5d84185f9fefaf21413efb764e4908057b8a9a0b987ede890c353490fd70"}, + {file = "safetensors-0.4.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2bd979642e6c3a517ef4b84ff36c2fee4015664fea05a61154fc565978347553"}, + {file = "safetensors-0.4.2-cp38-none-win32.whl", hash = "sha256:11be6e7afed29e5a5628f0aa6214e34bc194da73f558dc69fc7d56e07037422a"}, + {file = "safetensors-0.4.2-cp38-none-win_amd64.whl", hash = "sha256:2f7a6e5d29bd2cc340cffaa391fa437b1be9d21a2bd8b8724d2875d13a6ef2a9"}, + {file = "safetensors-0.4.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a5a921b4fe6925f9942adff3ebae8c16e0487908c54586a5a42f35b59fd69794"}, + {file = "safetensors-0.4.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b691727228c28f2d82d8a92b2bc26e7a1f129ee40b2f2a3185b5974e038ed47c"}, + {file = "safetensors-0.4.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91ca1056decc4e981248786e87b2a202d4841ee5f99d433f1adf3d44d4bcfa0e"}, + {file = "safetensors-0.4.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:55969fd2e6fdb38dc221b0ab380668c21b0efa12a7562db9924759faa3c51757"}, + {file = "safetensors-0.4.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6ae429bfaecc10ab5fe78c93009b3d1656c1581da560041e700eadb497dbe7a4"}, + {file = "safetensors-0.4.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ff88f194fe4ac50b463a4a6f0c03af9ad72eb5d24ec6d6730af59522e37fedb"}, + {file = "safetensors-0.4.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a80cb48d0a447f8dd18e61813efa7d3f8f8d52edf0f05806abc0c59b83431f57"}, + {file = 
"safetensors-0.4.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b286fb7adfee70a4189898ac2342b8a67d5f493e6b21b0af89ca8eac1b967cbf"}, + {file = "safetensors-0.4.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0ceeff9ddbab4f78738489eb6682867ae946178776f33699737b2129b5394dc1"}, + {file = "safetensors-0.4.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a26fae748a7488cb3aac381eddfa818c42052c87b5e689fb4c6e82ed58cec209"}, + {file = "safetensors-0.4.2-cp39-none-win32.whl", hash = "sha256:039a42ab33c9d68b39706fd38f1922ace26866eff246bf20271edb619f5f848b"}, + {file = "safetensors-0.4.2-cp39-none-win_amd64.whl", hash = "sha256:b3a3e1f5b85859e398773f064943b62a4059f225008a2a8ee6add1edcf77cacf"}, + {file = "safetensors-0.4.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:4e70d442ad17e8b153ef9095bf48ea64f15a66bf26dc2b6ca94660c154edbc24"}, + {file = "safetensors-0.4.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:b90f1d9809caf4ff395951b4703295a68d12907f6945bbc3129e934ff8ae46f6"}, + {file = "safetensors-0.4.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c7ac9ad3728838006598e296b3ae9f27d80b489effd4685b92d97b3fc4c98f6"}, + {file = "safetensors-0.4.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de5730d77e6ff7f4c7039e20913661ad0ea2f86c09e71c039e73dfdd1f394f08"}, + {file = "safetensors-0.4.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:44feb8cb156d6803dcd19fc6b81b27235f29b877660605a6ac35e1da7d64f0e4"}, + {file = "safetensors-0.4.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:523a241c33e7c827ab9a3a23760d75c7d062f43dfe55b6b019409f89b0fb52d1"}, + {file = "safetensors-0.4.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fb18300e8eb74291225214f26c9a8ae2110fd61a6c9b5a2ff4c4e0eb1bb9a998"}, + {file = "safetensors-0.4.2-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = 
"sha256:fe5437ff9fb116e44f2ab558981249ae63f978392b4576e62fcfe167d353edbc"}, + {file = "safetensors-0.4.2-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9304a0934ced5a5d272f39de36291dc141dfc152d277f03fb4d65f2fb2ffa7c"}, + {file = "safetensors-0.4.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:160ba1b1e11cf874602c233ab80a14f588571d09556cbc3586900121d622b5ed"}, + {file = "safetensors-0.4.2-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04fcd6fcf7d9c13c7e5dc7e08de5e492ee4daa8f4ad74b4d8299d3eb0224292f"}, + {file = "safetensors-0.4.2-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:906d14c4a677d35834fb0f3a5455ef8305e1bba10a5e0f2e0f357b3d1ad989f2"}, + {file = "safetensors-0.4.2-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:df3fcdec0cd543084610d1f09c65cdb10fb3079f79bceddc092b0d187c6a265b"}, + {file = "safetensors-0.4.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5ca76f13fb1cef242ea3ad2cb37388e7d005994f42af8b44bee56ba48b2d45ce"}, + {file = "safetensors-0.4.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:278a1a3414c020785decdcd741c578725721274d2f9f787fcc930882e83b89cc"}, + {file = "safetensors-0.4.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05b5a461cc68ecd42d9d546e5e1268a39d8ede7934a68d1ce17c3c659cb829d6"}, + {file = "safetensors-0.4.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2341411412a41671d25e26bed59ec121e46bf4fadb8132895e610411c4b9681"}, + {file = "safetensors-0.4.2-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3497ac3895acf17c5f98197f1fa4769f09c5e7ede07fcb102f1c201e663e052c"}, + {file = "safetensors-0.4.2-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:01b5e71d3754d2201294f1eb7a6d59cce3a5702ff96d83d226571b2ca2183837"}, + {file = "safetensors-0.4.2-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = 
"sha256:3627dbd1ea488dd8046a0491de5087f3c0d641e7acc80c0189a33c69398f1cd1"}, + {file = "safetensors-0.4.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:9d56f0ef53afad26ec54ceede78a43e9a23a076dadbbda7b44d304c591abf4c1"}, + {file = "safetensors-0.4.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b259ca73d42daf658a1bda463f1f83885ae4d93a60869be80d7f7dfcc9d8bbb5"}, + {file = "safetensors-0.4.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ebc3cd401e4eb54e7c0a70346be565e81942d9a41fafd5f4bf7ab3a55d10378"}, + {file = "safetensors-0.4.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5bc384a0309b706aa0425c93abb0390508a61bf029ce99c7d9df4220f25871a5"}, + {file = "safetensors-0.4.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:af2d8f7235d8a08fbccfb8394387890e7fa38942b349a94e6eff13c52ac98087"}, + {file = "safetensors-0.4.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:0911315bbcc5289087d063c2c2c7ccd711ea97a7e557a7bce005ac2cf80146aa"}, + {file = "safetensors-0.4.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:1efe31673be91832d73439a2af426743e1395fc9ef7b081914e9e1d567bd7b5f"}, + {file = "safetensors-0.4.2.tar.gz", hash = "sha256:acc85dcb09ec5e8aa787f588d7ad4d55c103f31e4ff060e17d92cc0e8b8cac73"}, ] [package.extras] all = ["safetensors[jax]", "safetensors[numpy]", "safetensors[paddlepaddle]", "safetensors[pinned-tf]", "safetensors[quality]", "safetensors[testing]", "safetensors[torch]"] dev = ["safetensors[all]"] jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "safetensors[numpy]"] +mlx = ["mlx (>=0.0.9)"] numpy = ["numpy (>=1.21.6)"] paddlepaddle = ["paddlepaddle (>=2.4.1)", "safetensors[numpy]"] pinned-tf = ["safetensors[numpy]", "tensorflow (==2.11.0)"] @@ -3848,15 +3937,77 @@ nativelib = ["pyobjc-framework-Cocoa", "pywin32"] objc = ["pyobjc-framework-Cocoa"] win32 = ["pywin32"] +[[package]] +name = "sentencepiece" 
+version = "0.2.0" +description = "SentencePiece python wrapper" +optional = false +python-versions = "*" +files = [ + {file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:188779e1298a1c8b8253c7d3ad729cb0a9891e5cef5e5d07ce4592c54869e227"}, + {file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bed9cf85b296fa2b76fc2547b9cbb691a523864cebaee86304c43a7b4cb1b452"}, + {file = "sentencepiece-0.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d7b67e724bead13f18db6e1d10b6bbdc454af574d70efbb36f27d90387be1ca3"}, + {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fde4b08cfe237be4484c6c7c2e2c75fb862cfeab6bd5449ce4caeafd97b767a"}, + {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c378492056202d1c48a4979650981635fd97875a00eabb1f00c6a236b013b5e"}, + {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1380ce6540a368de2ef6d7e6ba14ba8f3258df650d39ba7d833b79ee68a52040"}, + {file = "sentencepiece-0.2.0-cp310-cp310-win32.whl", hash = "sha256:a1151d6a6dd4b43e552394aed0edfe9292820272f0194bd56c7c1660a0c06c3d"}, + {file = "sentencepiece-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:d490142b0521ef22bc1085f061d922a2a6666175bb6b42e588ff95c0db6819b2"}, + {file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:17982700c4f6dbb55fa3594f3d7e5dd1c8659a274af3738e33c987d2a27c9d5c"}, + {file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7c867012c0e8bcd5bdad0f791609101cb5c66acb303ab3270218d6debc68a65e"}, + {file = "sentencepiece-0.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7fd6071249c74f779c5b27183295b9202f8dedb68034e716784364443879eaa6"}, + {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:27f90c55a65013cbb8f4d7aab0599bf925cde4adc67ae43a0d323677b5a1c6cb"}, + {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b293734059ef656dcd65be62ff771507bea8fed0a711b6733976e1ed3add4553"}, + {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e58b47f933aca74c6a60a79dcb21d5b9e47416256c795c2d58d55cec27f9551d"}, + {file = "sentencepiece-0.2.0-cp311-cp311-win32.whl", hash = "sha256:c581258cf346b327c62c4f1cebd32691826306f6a41d8c4bec43b010dee08e75"}, + {file = "sentencepiece-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:0993dbc665f4113017892f1b87c3904a44d0640eda510abcacdfb07f74286d36"}, + {file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea5f536e32ea8ec96086ee00d7a4a131ce583a1b18d130711707c10e69601cb2"}, + {file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0cb51f53b6aae3c36bafe41e86167c71af8370a039f542c43b0cce5ef24a68c"}, + {file = "sentencepiece-0.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3212121805afc58d8b00ab4e7dd1f8f76c203ddb9dc94aa4079618a31cf5da0f"}, + {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a3149e3066c2a75e0d68a43eb632d7ae728c7925b517f4c05c40f6f7280ce08"}, + {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:632f3594d3e7ac8b367bca204cb3fd05a01d5b21455acd097ea4c0e30e2f63d7"}, + {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f295105c6bdbb05bd5e1b0cafbd78ff95036f5d3641e7949455a3f4e5e7c3109"}, + {file = "sentencepiece-0.2.0-cp312-cp312-win32.whl", hash = "sha256:fb89f811e5efd18bab141afc3fea3de141c3f69f3fe9e898f710ae7fe3aab251"}, + {file = "sentencepiece-0.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a673a72aab81fef5ebe755c6e0cc60087d1f3a4700835d40537183c1703a45f"}, + {file = 
"sentencepiece-0.2.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4547683f330289ec4f093027bfeb87f9ef023b2eb6f879fdc4a8187c7e0ffb90"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cd6175f7eaec7142d2bf6f6597ce7db4c9ac89acf93fcdb17410c3a8b781eeb"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:859ba1acde782609a0910a26a60e16c191a82bf39b5621107552c0cd79fad00f"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbbef6cc277f8f18f36959e305f10b1c620442d75addc79c21d7073ae581b50"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-win32.whl", hash = "sha256:536b934e244829e3fe6c4f198652cd82da48adb9aa145c9f00889542726dee3d"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-win_amd64.whl", hash = "sha256:0a91aaa3c769b52440df56fafda683b3aa48e3f2169cf7ee5b8c8454a7f3ae9b"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:787e480ca4c1d08c9985a7eb1eae4345c107729c99e9b5a9a00f2575fc7d4b4b"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4d158189eb2ecffea3a51edf6d25e110b3678ec47f1a40f2d541eafbd8f6250"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1e5ca43013e8935f25457a4fca47e315780172c3e821b4b13a890668911c792"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7140d9e5a74a0908493bb4a13f1f16a401297bd755ada4c707e842fbf6f0f5bf"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-win32.whl", hash = "sha256:6cf333625234f247ab357b0bd9836638405ea9082e1543d5b8408f014979dcbf"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ff88712338b01031910e8e61e7239aff3ce8869ee31a47df63cb38aadd591bea"}, + {file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = 
"sha256:20813a68d4c221b1849c62c30e1281ea81687894d894b8d4a0f4677d9311e0f5"}, + {file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:926ef920ae2e8182db31d3f5d081ada57804e3e1d3a8c4ef8b117f9d9fb5a945"}, + {file = "sentencepiece-0.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:89f65f69636b7e9c015b79dff9c9985a9bc7d19ded6f79ef9f1ec920fdd73ecf"}, + {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f67eae0dbe6f2d7d6ba50a354623d787c99965f068b81e145d53240198021b0"}, + {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:98501e075f35dd1a1d5a20f65be26839fcb1938752ec61539af008a5aa6f510b"}, + {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3d1d2cc4882e8d6a1adf9d5927d7716f80617fc693385661caff21888972269"}, + {file = "sentencepiece-0.2.0-cp38-cp38-win32.whl", hash = "sha256:b99a308a2e5e569031ab164b74e6fab0b6f37dfb493c32f7816225f4d411a6dd"}, + {file = "sentencepiece-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:cdb701eec783d3ec86b7cd4c763adad8eaf6b46db37ee1c36e5e6c44b3fe1b5f"}, + {file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1e0f9c4d0a6b0af59b613175f019916e28ade076e21242fd5be24340d8a2f64a"}, + {file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:298f21cc1366eb60311aedba3169d30f885c363ddbf44214b0a587d2908141ad"}, + {file = "sentencepiece-0.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3f1ec95aa1e5dab11f37ac7eff190493fd87770f7a8b81ebc9dd768d1a3c8704"}, + {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b06b70af54daa4b4904cbb90b4eb6d35c9f3252fdc86c9c32d5afd4d30118d8"}, + {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22e37bac44dd6603388cb598c64ff7a76e41ca774646f21c23aadfbf5a2228ab"}, + {file = 
"sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0461324897735512a32d222e3d886e24ad6a499761952b6bda2a9ee6e4313ea5"}, + {file = "sentencepiece-0.2.0-cp39-cp39-win32.whl", hash = "sha256:38aed822fb76435fa1f12185f10465a94ab9e51d5e8a9159e9a540ce926f0ffd"}, + {file = "sentencepiece-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:d8cf876516548b5a1d6ac4745d8b554f5c07891d55da557925e5c13ff0b4e6ad"}, + {file = "sentencepiece-0.2.0.tar.gz", hash = "sha256:a52c19171daaf2e697dc6cbe67684e0fa341b1248966f6aebb541de654d15843"}, +] + [[package]] name = "sentry-sdk" -version = "1.39.1" +version = "1.44.0" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.39.1.tar.gz", hash = "sha256:320a55cdf9da9097a0bead239c35b7e61f53660ef9878861824fd6d9b2eaf3b5"}, - {file = "sentry_sdk-1.39.1-py2.py3-none-any.whl", hash = "sha256:81b5b9ffdd1a374e9eb0c053b5d2012155db9cbe76393a8585677b753bd5fdc1"}, + {file = "sentry-sdk-1.44.0.tar.gz", hash = "sha256:f7125a9235795811962d52ff796dc032cd1d0dd98b59beaced8380371cd9c13c"}, + {file = "sentry_sdk-1.44.0-py2.py3-none-any.whl", hash = "sha256:eb65289da013ca92fad2694851ad2f086aa3825e808dc285bd7dcaf63602bb18"}, ] [package.dependencies] @@ -3870,6 +4021,7 @@ asyncpg = ["asyncpg (>=0.23)"] beam = ["apache-beam (>=2.12)"] bottle = ["bottle (>=0.12.13)"] celery = ["celery (>=3)"] +celery-redbeat = ["celery-redbeat (>=2)"] chalice = ["chalice (>=1.16.0)"] clickhouse-driver = ["clickhouse-driver (>=0.2.0)"] django = ["django (>=1.8)"] @@ -3880,6 +4032,7 @@ grpcio = ["grpcio (>=1.21.1)"] httpx = ["httpx (>=0.16.0)"] huey = ["huey (>=2)"] loguru = ["loguru (>=0.5)"] +openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"] opentelemetry = ["opentelemetry-distro (>=0.35b0)"] opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django 
(>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"] pure-eval = ["asttokens", "executing", "pure-eval"] @@ -3995,19 +4148,19 @@ test = ["pytest"] [[package]] name = "setuptools" -version = "69.0.3" +version = "69.2.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-69.0.3-py3-none-any.whl", hash = "sha256:385eb4edd9c9d5c17540511303e39a147ce2fc04bc55289c322b9e5904fe2c05"}, - {file = "setuptools-69.0.3.tar.gz", hash = "sha256:be1af57fc409f93647f2e8e4573a142ed38724b8cdd389706a867bb4efcf1e78"}, + {file = "setuptools-69.2.0-py3-none-any.whl", hash = "sha256:c21c49fb1042386df081cb5d86759792ab89efca84cf114889191cd09aacc80c"}, + {file = "setuptools-69.2.0.tar.gz", hash = "sha256:0ff4183f8f42cd8fa3acea16c45205521a4ef28f73c6391d8a25e92893134f2e"}, ] [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] -testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", 
"packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] [[package]] name = "six" @@ -4033,13 +4186,13 @@ files = [ [[package]] name = "sniffio" -version = "1.3.0" +version = "1.3.1" description = "Sniff out which async library your code is running under" optional = false python-versions = ">=3.7" files = [ - {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, - {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] [[package]] @@ -4302,13 +4455,13 @@ doc = ["reno", "sphinx", "tornado (>=4.5)"] [[package]] name = "terminado" -version = "0.18.0" +version = "0.18.1" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." 
optional = false python-versions = ">=3.8" files = [ - {file = "terminado-0.18.0-py3-none-any.whl", hash = "sha256:87b0d96642d0fe5f5abd7783857b9cab167f221a39ff98e3b9619a788a3c0f2e"}, - {file = "terminado-0.18.0.tar.gz", hash = "sha256:1ea08a89b835dd1b8c0c900d92848147cef2537243361b2e3f4dc15df9b6fded"}, + {file = "terminado-0.18.1-py3-none-any.whl", hash = "sha256:a4468e1b37bb318f8a86514f65814e1afc977cf29b3992a4500d9dd305dcceb0"}, + {file = "terminado-0.18.1.tar.gz", hash = "sha256:de09f2c4b85de4765f7714688fff57d3e75bad1f909b589fde880460c753fd2e"}, ] [package.dependencies] @@ -4341,109 +4494,121 @@ test = ["flake8", "isort", "pytest"] [[package]] name = "tokenizers" -version = "0.15.0" +version = "0.15.2" description = "" optional = false python-versions = ">=3.7" files = [ - {file = "tokenizers-0.15.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:cd3cd0299aaa312cd2988957598f80becd04d5a07338741eca076057a2b37d6e"}, - {file = "tokenizers-0.15.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a922c492c721744ee175f15b91704be2d305569d25f0547c77cd6c9f210f9dc"}, - {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:331dd786d02fc38698f835fff61c99480f98b73ce75a4c65bd110c9af5e4609a"}, - {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88dd0961c437d413ab027f8b115350c121d49902cfbadf08bb8f634b15fa1814"}, - {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6fdcc55339df7761cd52e1fbe8185d3b3963bc9e3f3545faa6c84f9e8818259a"}, - {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1480b0051d8ab5408e8e4db2dc832f7082ea24aa0722c427bde2418c6f3bd07"}, - {file = "tokenizers-0.15.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9855e6c258918f9cf62792d4f6ddfa6c56dccd8c8118640f867f6393ecaf8bd7"}, - {file = 
"tokenizers-0.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de9529fe75efcd54ba8d516aa725e1851df9199f0669b665c55e90df08f5af86"}, - {file = "tokenizers-0.15.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:8edcc90a36eab0705fe9121d6c77c6e42eeef25c7399864fd57dfb27173060bf"}, - {file = "tokenizers-0.15.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ae17884aafb3e94f34fb7cfedc29054f5f54e142475ebf8a265a4e388fee3f8b"}, - {file = "tokenizers-0.15.0-cp310-none-win32.whl", hash = "sha256:9a3241acdc9b44cff6e95c4a55b9be943ef3658f8edb3686034d353734adba05"}, - {file = "tokenizers-0.15.0-cp310-none-win_amd64.whl", hash = "sha256:4b31807cb393d6ea31926b307911c89a1209d5e27629aa79553d1599c8ffdefe"}, - {file = "tokenizers-0.15.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:af7e9be8c05d30bb137b9fd20f9d99354816599e5fd3d58a4b1e28ba3b36171f"}, - {file = "tokenizers-0.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c3d7343fa562ea29661783344a2d83662db0d3d17a6fa6a403cac8e512d2d9fd"}, - {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:32371008788aeeb0309a9244809a23e4c0259625e6b74a103700f6421373f395"}, - {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca9db64c7c9954fbae698884c5bb089764edc549731e5f9b7fa1dd4e4d78d77f"}, - {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dbed5944c31195514669cf6381a0d8d47f164943000d10f93d6d02f0d45c25e0"}, - {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aab16c4a26d351d63e965b0c792f5da7227a37b69a6dc6d922ff70aa595b1b0c"}, - {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3c2b60b12fdd310bf85ce5d7d3f823456b9b65eed30f5438dd7761879c495983"}, - {file = "tokenizers-0.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:0344d6602740e44054a9e5bbe9775a5e149c4dddaff15959bb07dcce95a5a859"}, - {file = "tokenizers-0.15.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4525f6997d81d9b6d9140088f4f5131f6627e4c960c2c87d0695ae7304233fc3"}, - {file = "tokenizers-0.15.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:65975094fef8cc68919644936764efd2ce98cf1bacbe8db2687155d2b0625bee"}, - {file = "tokenizers-0.15.0-cp311-none-win32.whl", hash = "sha256:ff5d2159c5d93015f5a4542aac6c315506df31853123aa39042672031768c301"}, - {file = "tokenizers-0.15.0-cp311-none-win_amd64.whl", hash = "sha256:2dd681b53cf615e60a31a115a3fda3980e543d25ca183797f797a6c3600788a3"}, - {file = "tokenizers-0.15.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:c9cce6ee149a3d703f86877bc2a6d997e34874b2d5a2d7839e36b2273f31d3d9"}, - {file = "tokenizers-0.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4a0a94bc3370e6f1cc8a07a8ae867ce13b7c1b4291432a773931a61f256d44ea"}, - {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:309cfcccfc7e502cb1f1de2c9c1c94680082a65bfd3a912d5a5b2c90c677eb60"}, - {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8413e994dd7d875ab13009127fc85633916c71213917daf64962bafd488f15dc"}, - {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d0ebf9430f901dbdc3dcb06b493ff24a3644c9f88c08e6a1d6d0ae2228b9b818"}, - {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:10361e9c7864b22dd791ec5126327f6c9292fb1d23481d4895780688d5e298ac"}, - {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:babe42635b8a604c594bdc56d205755f73414fce17ba8479d142a963a6c25cbc"}, - {file = "tokenizers-0.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3768829861e964c7a4556f5f23307fce6a23872c2ebf030eb9822dbbbf7e9b2a"}, - {file = 
"tokenizers-0.15.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9c91588a630adc88065e1c03ac6831e3e2112558869b9ebcb2b8afd8a14c944d"}, - {file = "tokenizers-0.15.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:77606994e793ca54ecf3a3619adc8a906a28ca223d9354b38df41cb8766a0ed6"}, - {file = "tokenizers-0.15.0-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:6fe143939f3b596681922b2df12a591a5b010e7dcfbee2202482cd0c1c2f2459"}, - {file = "tokenizers-0.15.0-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:b7bee0f1795e3e3561e9a557061b1539e5255b8221e3f928f58100282407e090"}, - {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5d37e7f4439b4c46192ab4f2ff38ab815e4420f153caa13dec9272ef14403d34"}, - {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caadf255cf7f951b38d10097836d1f3bcff4aeaaffadfdf748bab780bf5bff95"}, - {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:05accb9162bf711a941b1460b743d62fec61c160daf25e53c5eea52c74d77814"}, - {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26a2ef890740127cb115ee5260878f4a677e36a12831795fd7e85887c53b430b"}, - {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e54c5f26df14913620046b33e822cb3bcd091a332a55230c0e63cc77135e2169"}, - {file = "tokenizers-0.15.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:669b8ed653a578bcff919566631156f5da3aab84c66f3c0b11a6281e8b4731c7"}, - {file = "tokenizers-0.15.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0ea480d943297df26f06f508dab6e012b07f42bf3dffdd36e70799368a5f5229"}, - {file = "tokenizers-0.15.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:bc80a0a565ebfc7cd89de7dd581da8c2b3238addfca6280572d27d763f135f2f"}, - {file = "tokenizers-0.15.0-cp37-none-win32.whl", hash = 
"sha256:cdd945e678bbdf4517d5d8de66578a5030aeefecdb46f5320b034de9cad8d4dd"}, - {file = "tokenizers-0.15.0-cp37-none-win_amd64.whl", hash = "sha256:1ab96ab7dc706e002c32b2ea211a94c1c04b4f4de48354728c3a6e22401af322"}, - {file = "tokenizers-0.15.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:f21c9eb71c9a671e2a42f18b456a3d118e50c7f0fc4dd9fa8f4eb727fea529bf"}, - {file = "tokenizers-0.15.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2a5f4543a35889679fc3052086e69e81880b2a5a28ff2a52c5a604be94b77a3f"}, - {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f8aa81afec893e952bd39692b2d9ef60575ed8c86fce1fd876a06d2e73e82dca"}, - {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1574a5a4af22c3def93fe8fe4adcc90a39bf5797ed01686a4c46d1c3bc677d2f"}, - {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7c7982fd0ec9e9122d03b209dac48cebfea3de0479335100ef379a9a959b9a5a"}, - {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8d16b647032df2ce2c1f9097236e046ea9fedd969b25637b9d5d734d78aa53b"}, - {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b3cdf29e6f9653da330515dc8fa414be5a93aae79e57f8acc50d4028dd843edf"}, - {file = "tokenizers-0.15.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7286f3df10de840867372e3e64b99ef58c677210e3ceb653cd0e740a5c53fe78"}, - {file = "tokenizers-0.15.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aabc83028baa5a36ce7a94e7659250f0309c47fa4a639e5c2c38e6d5ea0de564"}, - {file = "tokenizers-0.15.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:72f78b0e0e276b1fc14a672fa73f3acca034ba8db4e782124a2996734a9ba9cf"}, - {file = "tokenizers-0.15.0-cp38-none-win32.whl", hash = "sha256:9680b0ecc26e7e42f16680c1aa62e924d58d1c2dd992707081cc10a374896ea2"}, - {file = 
"tokenizers-0.15.0-cp38-none-win_amd64.whl", hash = "sha256:f17cbd88dab695911cbdd385a5a7e3709cc61dff982351f5d1b5939f074a2466"}, - {file = "tokenizers-0.15.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:3661862df7382c5eb23ac4fbf7c75e69b02dc4f5784e4c5a734db406b5b24596"}, - {file = "tokenizers-0.15.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c3045d191dad49647f5a5039738ecf1c77087945c7a295f7bcf051c37067e883"}, - {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a9fcaad9ab0801f14457d7c820d9f246b5ab590c407fc6b073819b1573097aa7"}, - {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a79f17027f24fe9485701c8dbb269b9c713954ec3bdc1e7075a66086c0c0cd3c"}, - {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:01a3aa332abc4bee7640563949fcfedca4de8f52691b3b70f2fc6ca71bfc0f4e"}, - {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:05b83896a893cdfedad8785250daa3ba9f0504848323471524d4783d7291661e"}, - {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cbbf2489fcf25d809731ba2744ff278dd07d9eb3f8b7482726bd6cae607073a4"}, - {file = "tokenizers-0.15.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab806ad521a5e9de38078b7add97589c313915f6f5fec6b2f9f289d14d607bd6"}, - {file = "tokenizers-0.15.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4a522612d5c88a41563e3463226af64e2fa00629f65cdcc501d1995dd25d23f5"}, - {file = "tokenizers-0.15.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e58a38c4e6075810bdfb861d9c005236a72a152ebc7005941cc90d1bbf16aca9"}, - {file = "tokenizers-0.15.0-cp39-none-win32.whl", hash = "sha256:b8034f1041fd2bd2b84ff9f4dc4ae2e1c3b71606820a9cd5c562ebd291a396d1"}, - {file = "tokenizers-0.15.0-cp39-none-win_amd64.whl", hash = 
"sha256:edde9aa964145d528d0e0dbf14f244b8a85ebf276fb76869bc02e2530fa37a96"}, - {file = "tokenizers-0.15.0-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:309445d10d442b7521b98083dc9f0b5df14eca69dbbfebeb98d781ee2cef5d30"}, - {file = "tokenizers-0.15.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d3125a6499226d4d48efc54f7498886b94c418e93a205b673bc59364eecf0804"}, - {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ed56ddf0d54877bb9c6d885177db79b41576e61b5ef6defeb579dcb803c04ad5"}, - {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b22cd714706cc5b18992a232b023f736e539495f5cc61d2d28d176e55046f6c"}, - {file = "tokenizers-0.15.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fac2719b1e9bc8e8e7f6599b99d0a8e24f33d023eb8ef644c0366a596f0aa926"}, - {file = "tokenizers-0.15.0-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:85ddae17570ec7e5bfaf51ffa78d044f444a8693e1316e1087ee6150596897ee"}, - {file = "tokenizers-0.15.0-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:76f1bed992e396bf6f83e3df97b64ff47885e45e8365f8983afed8556a0bc51f"}, - {file = "tokenizers-0.15.0-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:3bb0f4df6dce41a1c7482087b60d18c372ef4463cb99aa8195100fcd41e0fd64"}, - {file = "tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:22c27672c27a059a5f39ff4e49feed8c7f2e1525577c8a7e3978bd428eb5869d"}, - {file = "tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78104f5d035c9991f92831fc0efe9e64a05d4032194f2a69f67aaa05a4d75bbb"}, - {file = "tokenizers-0.15.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a40b73dc19d82c3e3ffb40abdaacca8fbc95eeb26c66b7f9f860aebc07a73998"}, - {file = 
"tokenizers-0.15.0-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d801d1368188c74552cd779b1286e67cb9fd96f4c57a9f9a2a09b6def9e1ab37"}, - {file = "tokenizers-0.15.0-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:82641ffb13a4da1293fcc9f437d457647e60ed0385a9216cd135953778b3f0a1"}, - {file = "tokenizers-0.15.0-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:160f9d1810f2c18fffa94aa98bf17632f6bd2dabc67fcb01a698ca80c37d52ee"}, - {file = "tokenizers-0.15.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:8d7d6eea831ed435fdeeb9bcd26476226401d7309d115a710c65da4088841948"}, - {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f6456bec6c557d63d8ec0023758c32f589e1889ed03c055702e84ce275488bed"}, - {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1eef39a502fad3bf104b9e1906b4fb0cee20e44e755e51df9a98f8922c3bf6d4"}, - {file = "tokenizers-0.15.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1e4664c5b797e093c19b794bbecc19d2367e782b4a577d8b7c1821db5dc150d"}, - {file = "tokenizers-0.15.0-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ca003fb5f3995ff5cf676db6681b8ea5d54d3b30bea36af1120e78ee1a4a4cdf"}, - {file = "tokenizers-0.15.0-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7f17363141eb0c53752c89e10650b85ef059a52765d0802ba9613dbd2d21d425"}, - {file = "tokenizers-0.15.0-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:8a765db05581c7d7e1280170f2888cda351760d196cc059c37ea96f121125799"}, - {file = "tokenizers-0.15.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:2a0dd641a72604486cd7302dd8f87a12c8a9b45e1755e47d2682733f097c1af5"}, - {file = "tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0a1a3c973e4dc97797fc19e9f11546c95278ffc55c4492acb742f69e035490bc"}, - {file = 
"tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4fab75642aae4e604e729d6f78e0addb9d7e7d49e28c8f4d16b24da278e5263"}, - {file = "tokenizers-0.15.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65f80be77f6327a86d8fd35a4467adcfe6174c159b4ab52a1a8dd4c6f2d7d9e1"}, - {file = "tokenizers-0.15.0-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:a8da7533dbe66b88afd430c56a2f2ce1fd82e2681868f857da38eeb3191d7498"}, - {file = "tokenizers-0.15.0-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fa8eb4584fc6cbe6a84d7a7864be3ed28e23e9fd2146aa8ef1814d579df91958"}, - {file = "tokenizers-0.15.0.tar.gz", hash = "sha256:10c7e6e7b4cabd757da59e93f5f8d1126291d16f8b54f28510825ef56a3e5d0e"}, + {file = "tokenizers-0.15.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:52f6130c9cbf70544287575a985bf44ae1bda2da7e8c24e97716080593638012"}, + {file = "tokenizers-0.15.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:054c1cc9c6d68f7ffa4e810b3d5131e0ba511b6e4be34157aa08ee54c2f8d9ee"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a9b9b070fdad06e347563b88c278995735292ded1132f8657084989a4c84a6d5"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea621a7eef4b70e1f7a4e84dd989ae3f0eeb50fc8690254eacc08acb623e82f1"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cf7fd9a5141634fa3aa8d6b7be362e6ae1b4cda60da81388fa533e0b552c98fd"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44f2a832cd0825295f7179eaf173381dc45230f9227ec4b44378322d900447c9"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8b9ec69247a23747669ec4b0ca10f8e3dfb3545d550258129bd62291aabe8605"}, + {file = 
"tokenizers-0.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40b6a4c78da863ff26dbd5ad9a8ecc33d8a8d97b535172601cf00aee9d7ce9ce"}, + {file = "tokenizers-0.15.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5ab2a4d21dcf76af60e05af8063138849eb1d6553a0d059f6534357bce8ba364"}, + {file = "tokenizers-0.15.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a47acfac7e511f6bbfcf2d3fb8c26979c780a91e06fb5b9a43831b2c0153d024"}, + {file = "tokenizers-0.15.2-cp310-none-win32.whl", hash = "sha256:064ff87bb6acdbd693666de9a4b692add41308a2c0ec0770d6385737117215f2"}, + {file = "tokenizers-0.15.2-cp310-none-win_amd64.whl", hash = "sha256:3b919afe4df7eb6ac7cafd2bd14fb507d3f408db7a68c43117f579c984a73843"}, + {file = "tokenizers-0.15.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:89cd1cb93e4b12ff39bb2d626ad77e35209de9309a71e4d3d4672667b4b256e7"}, + {file = "tokenizers-0.15.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cfed5c64e5be23d7ee0f0e98081a25c2a46b0b77ce99a4f0605b1ec43dd481fa"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a907d76dcfda37023ba203ab4ceeb21bc5683436ebefbd895a0841fd52f6f6f2"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20ea60479de6fc7b8ae756b4b097572372d7e4032e2521c1bbf3d90c90a99ff0"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:48e2b9335be2bc0171df9281385c2ed06a15f5cf121c44094338306ab7b33f2c"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:112a1dd436d2cc06e6ffdc0b06d55ac019a35a63afd26475205cb4b1bf0bfbff"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4620cca5c2817177ee8706f860364cc3a8845bc1e291aaf661fb899e5d1c45b0"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:ccd73a82751c523b3fc31ff8194702e4af4db21dc20e55b30ecc2079c5d43cb7"}, + {file = "tokenizers-0.15.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:107089f135b4ae7817affe6264f8c7a5c5b4fd9a90f9439ed495f54fcea56fb4"}, + {file = "tokenizers-0.15.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0ff110ecc57b7aa4a594396525a3451ad70988e517237fe91c540997c4e50e29"}, + {file = "tokenizers-0.15.2-cp311-none-win32.whl", hash = "sha256:6d76f00f5c32da36c61f41c58346a4fa7f0a61be02f4301fd30ad59834977cc3"}, + {file = "tokenizers-0.15.2-cp311-none-win_amd64.whl", hash = "sha256:cc90102ed17271cf0a1262babe5939e0134b3890345d11a19c3145184b706055"}, + {file = "tokenizers-0.15.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f86593c18d2e6248e72fb91c77d413a815153b8ea4e31f7cd443bdf28e467670"}, + {file = "tokenizers-0.15.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0774bccc6608eca23eb9d620196687c8b2360624619623cf4ba9dc9bd53e8b51"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d0222c5b7c9b26c0b4822a82f6a7011de0a9d3060e1da176f66274b70f846b98"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3835738be1de66624fff2f4f6f6684775da4e9c00bde053be7564cbf3545cc66"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0143e7d9dcd811855c1ce1ab9bf5d96d29bf5e528fd6c7824d0465741e8c10fd"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db35825f6d54215f6b6009a7ff3eedee0848c99a6271c870d2826fbbedf31a38"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f5e64b0389a2be47091d8cc53c87859783b837ea1a06edd9d8e04004df55a5c"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e0480c452217edd35eca56fafe2029fb4d368b7c0475f8dfa3c5c9c400a7456"}, + {file = 
"tokenizers-0.15.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a33ab881c8fe70474980577e033d0bc9a27b7ab8272896e500708b212995d834"}, + {file = "tokenizers-0.15.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a308a607ca9de2c64c1b9ba79ec9a403969715a1b8ba5f998a676826f1a7039d"}, + {file = "tokenizers-0.15.2-cp312-none-win32.whl", hash = "sha256:b8fcfa81bcb9447df582c5bc96a031e6df4da2a774b8080d4f02c0c16b42be0b"}, + {file = "tokenizers-0.15.2-cp312-none-win_amd64.whl", hash = "sha256:38d7ab43c6825abfc0b661d95f39c7f8af2449364f01d331f3b51c94dcff7221"}, + {file = "tokenizers-0.15.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:38bfb0204ff3246ca4d5e726e8cc8403bfc931090151e6eede54d0e0cf162ef0"}, + {file = "tokenizers-0.15.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9c861d35e8286a53e06e9e28d030b5a05bcbf5ac9d7229e561e53c352a85b1fc"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:936bf3842db5b2048eaa53dade907b1160f318e7c90c74bfab86f1e47720bdd6"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:620beacc3373277700d0e27718aa8b25f7b383eb8001fba94ee00aeea1459d89"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2735ecbbf37e52db4ea970e539fd2d450d213517b77745114f92867f3fc246eb"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:473c83c5e2359bb81b0b6fde870b41b2764fcdd36d997485e07e72cc3a62264a"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:968fa1fb3c27398b28a4eca1cbd1e19355c4d3a6007f7398d48826bbe3a0f728"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:865c60ae6eaebdde7da66191ee9b7db52e542ed8ee9d2c653b6d190a9351b980"}, + {file = "tokenizers-0.15.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = 
"sha256:7c0d8b52664ab2d4a8d6686eb5effc68b78608a9008f086a122a7b2996befbab"}, + {file = "tokenizers-0.15.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:f33dfbdec3784093a9aebb3680d1f91336c56d86cc70ddf88708251da1fe9064"}, + {file = "tokenizers-0.15.2-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:d44ba80988ff9424e33e0a49445072ac7029d8c0e1601ad25a0ca5f41ed0c1d6"}, + {file = "tokenizers-0.15.2-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:dce74266919b892f82b1b86025a613956ea0ea62a4843d4c4237be2c5498ed3a"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0ef06b9707baeb98b316577acb04f4852239d856b93e9ec3a299622f6084e4be"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c73e2e74bbb07910da0d37c326869f34113137b23eadad3fc00856e6b3d9930c"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4eeb12daf02a59e29f578a865f55d87cd103ce62bd8a3a5874f8fdeaa82e336b"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ba9f6895af58487ca4f54e8a664a322f16c26bbb442effd01087eba391a719e"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ccec77aa7150e38eec6878a493bf8c263ff1fa8a62404e16c6203c64c1f16a26"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3f40604f5042ff210ba82743dda2b6aa3e55aa12df4e9f2378ee01a17e2855e"}, + {file = "tokenizers-0.15.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5645938a42d78c4885086767c70923abad047163d809c16da75d6b290cb30bbe"}, + {file = "tokenizers-0.15.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:05a77cbfebe28a61ab5c3891f9939cc24798b63fa236d84e5f29f3a85a200c00"}, + {file = "tokenizers-0.15.2-cp37-none-win32.whl", hash = "sha256:361abdc068e8afe9c5b818769a48624687fb6aaed49636ee39bec4e95e1a215b"}, + {file = 
"tokenizers-0.15.2-cp37-none-win_amd64.whl", hash = "sha256:7ef789f83eb0f9baeb4d09a86cd639c0a5518528f9992f38b28e819df397eb06"}, + {file = "tokenizers-0.15.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:4fe1f74a902bee74a3b25aff180fbfbf4f8b444ab37c4d496af7afd13a784ed2"}, + {file = "tokenizers-0.15.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c4b89038a684f40a6b15d6b09f49650ac64d951ad0f2a3ea9169687bbf2a8ba"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d05a1b06f986d41aed5f2de464c003004b2df8aaf66f2b7628254bcbfb72a438"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:508711a108684111ec8af89d3a9e9e08755247eda27d0ba5e3c50e9da1600f6d"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:daa348f02d15160cb35439098ac96e3a53bacf35885072611cd9e5be7d333daa"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:494fdbe5932d3416de2a85fc2470b797e6f3226c12845cadf054dd906afd0442"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c2d60f5246f4da9373f75ff18d64c69cbf60c3bca597290cea01059c336d2470"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93268e788825f52de4c7bdcb6ebc1fcd4a5442c02e730faa9b6b08f23ead0e24"}, + {file = "tokenizers-0.15.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6fc7083ab404019fc9acafe78662c192673c1e696bd598d16dc005bd663a5cf9"}, + {file = "tokenizers-0.15.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:41e39b41e5531d6b2122a77532dbea60e171ef87a3820b5a3888daa847df4153"}, + {file = "tokenizers-0.15.2-cp38-none-win32.whl", hash = "sha256:06cd0487b1cbfabefb2cc52fbd6b1f8d4c37799bd6c6e1641281adaa6b2504a7"}, + {file = "tokenizers-0.15.2-cp38-none-win_amd64.whl", hash = 
"sha256:5179c271aa5de9c71712e31cb5a79e436ecd0d7532a408fa42a8dbfa4bc23fd9"}, + {file = "tokenizers-0.15.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:82f8652a74cc107052328b87ea8b34291c0f55b96d8fb261b3880216a9f9e48e"}, + {file = "tokenizers-0.15.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:02458bee6f5f3139f1ebbb6d042b283af712c0981f5bc50edf771d6b762d5e4f"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:c9a09cd26cca2e1c349f91aa665309ddb48d71636370749414fbf67bc83c5343"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:158be8ea8554e5ed69acc1ce3fbb23a06060bd4bbb09029431ad6b9a466a7121"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ddba9a2b0c8c81633eca0bb2e1aa5b3a15362b1277f1ae64176d0f6eba78ab1"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ef5dd1d39797044642dbe53eb2bc56435308432e9c7907728da74c69ee2adca"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:454c203164e07a860dbeb3b1f4a733be52b0edbb4dd2e5bd75023ffa8b49403a"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cf6b7f1d4dc59af960e6ffdc4faffe6460bbfa8dce27a58bf75755ffdb2526d"}, + {file = "tokenizers-0.15.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2ef09bbc16519f6c25d0c7fc0c6a33a6f62923e263c9d7cca4e58b8c61572afb"}, + {file = "tokenizers-0.15.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c9a2ebdd2ad4ec7a68e7615086e633857c85e2f18025bd05d2a4399e6c5f7169"}, + {file = "tokenizers-0.15.2-cp39-none-win32.whl", hash = "sha256:918fbb0eab96fe08e72a8c2b5461e9cce95585d82a58688e7f01c2bd546c79d0"}, + {file = "tokenizers-0.15.2-cp39-none-win_amd64.whl", hash = "sha256:524e60da0135e106b254bd71f0659be9f89d83f006ea9093ce4d1fab498c6d0d"}, + {file = 
"tokenizers-0.15.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6a9b648a58281c4672212fab04e60648fde574877d0139cd4b4f93fe28ca8944"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7c7d18b733be6bbca8a55084027f7be428c947ddf871c500ee603e375013ffba"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:13ca3611de8d9ddfbc4dc39ef54ab1d2d4aaa114ac8727dfdc6a6ec4be017378"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:237d1bf3361cf2e6463e6c140628e6406766e8b27274f5fcc62c747ae3c6f094"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67a0fe1e49e60c664915e9fb6b0cb19bac082ab1f309188230e4b2920230edb3"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:4e022fe65e99230b8fd89ebdfea138c24421f91c1a4f4781a8f5016fd5cdfb4d"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:d857be2df69763362ac699f8b251a8cd3fac9d21893de129bc788f8baaef2693"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:708bb3e4283177236309e698da5fcd0879ce8fd37457d7c266d16b550bcbbd18"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:64c35e09e9899b72a76e762f9854e8750213f67567787d45f37ce06daf57ca78"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1257f4394be0d3b00de8c9e840ca5601d0a4a8438361ce9c2b05c7d25f6057b"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02272fe48280e0293a04245ca5d919b2c94a48b408b55e858feae9618138aeda"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = 
"sha256:dc3ad9ebc76eabe8b1d7c04d38be884b8f9d60c0cdc09b0aa4e3bcf746de0388"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:32e16bdeffa7c4f46bf2152172ca511808b952701d13e7c18833c0b73cb5c23f"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fb16ba563d59003028b678d2361a27f7e4ae0ab29c7a80690efa20d829c81fdb"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:2277c36d2d6cdb7876c274547921a42425b6810d38354327dd65a8009acf870c"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:1cf75d32e8d250781940d07f7eece253f2fe9ecdb1dc7ba6e3833fa17b82fcbc"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1b3b31884dc8e9b21508bb76da80ebf7308fdb947a17affce815665d5c4d028"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b10122d8d8e30afb43bb1fe21a3619f62c3e2574bff2699cf8af8b0b6c5dc4a3"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d88b96ff0fe8e91f6ef01ba50b0d71db5017fa4e3b1d99681cec89a85faf7bf7"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:37aaec5a52e959892870a7c47cef80c53797c0db9149d458460f4f31e2fb250e"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e2ea752f2b0fe96eb6e2f3adbbf4d72aaa1272079b0dfa1145507bd6a5d537e6"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:4b19a808d8799fda23504a5cd31d2f58e6f52f140380082b352f877017d6342b"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:64c86e5e068ac8b19204419ed8ca90f9d25db20578f5881e337d203b314f4104"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:de19c4dc503c612847edf833c82e9f73cd79926a384af9d801dcf93f110cea4e"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea09acd2fe3324174063d61ad620dec3bcf042b495515f27f638270a7d466e8b"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:cf27fd43472e07b57cf420eee1e814549203d56de00b5af8659cb99885472f1f"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7ca22bd897537a0080521445d91a58886c8c04084a6a19e6c78c586e0cfa92a5"}, + {file = "tokenizers-0.15.2.tar.gz", hash = "sha256:e6e9c6e019dd5484be5beafc775ae6c925f4c69a3487040ed09b45e13df2cb91"}, ] [package.dependencies] @@ -4467,42 +4632,47 @@ files = [ [[package]] name = "tomlkit" -version = "0.12.3" +version = "0.12.4" description = "Style preserving TOML library" optional = false python-versions = ">=3.7" files = [ - {file = "tomlkit-0.12.3-py3-none-any.whl", hash = "sha256:b0a645a9156dc7cb5d3a1f0d4bab66db287fcb8e0430bdd4664a095ea16414ba"}, - {file = "tomlkit-0.12.3.tar.gz", hash = "sha256:75baf5012d06501f07bee5bf8e801b9f343e7aac5a92581f20f80ce632e6b5a4"}, + {file = "tomlkit-0.12.4-py3-none-any.whl", hash = "sha256:5cd82d48a3dd89dee1f9d64420aa20ae65cfbd00668d6f094d7578a78efbb77b"}, + {file = "tomlkit-0.12.4.tar.gz", hash = "sha256:7ca1cfc12232806517a8515047ba66a19369e71edf2439d0f5824f91032b6cc3"}, ] [[package]] name = "torch" -version = "2.1.2" +version = "2.2.2" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.1.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:3a871edd6c02dae77ad810335c0833391c1a4ce49af21ea8cf0f6a5d2096eea8"}, - {file = "torch-2.1.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076"}, - {file = "torch-2.1.2-cp310-cp310-win_amd64.whl", hash = 
"sha256:0e13034fd5fb323cbbc29e56d0637a3791e50dd589616f40c79adfa36a5a35a1"}, - {file = "torch-2.1.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:d9b535cad0df3d13997dbe8bd68ac33e0e3ae5377639c9881948e40794a61403"}, - {file = "torch-2.1.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:f9a55d55af02826ebfbadf4e9b682f0f27766bc33df8236b48d28d705587868f"}, - {file = "torch-2.1.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:a6ebbe517097ef289cc7952783588c72de071d4b15ce0f8b285093f0916b1162"}, - {file = "torch-2.1.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:8f32ce591616a30304f37a7d5ea80b69ca9e1b94bba7f308184bf616fdaea155"}, - {file = "torch-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:e0ee6cf90c8970e05760f898d58f9ac65821c37ffe8b04269ec787aa70962b69"}, - {file = "torch-2.1.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:76d37967c31c99548ad2c4d3f2cf191db48476f2e69b35a0937137116da356a1"}, - {file = "torch-2.1.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e2d83f07b4aac983453ea5bf8f9aa9dacf2278a8d31247f5d9037f37befc60e4"}, - {file = "torch-2.1.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f41fe0c7ecbf903a568c73486139a75cfab287a0f6c17ed0698fdea7a1e8641d"}, - {file = "torch-2.1.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e3225f47d50bb66f756fe9196a768055d1c26b02154eb1f770ce47a2578d3aa7"}, - {file = "torch-2.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:33d59cd03cb60106857f6c26b36457793637512998666ee3ce17311f217afe2b"}, - {file = "torch-2.1.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:8e221deccd0def6c2badff6be403e0c53491805ed9915e2c029adbcdb87ab6b5"}, - {file = "torch-2.1.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:05b18594f60a911a0c4f023f38a8bda77131fba5fd741bda626e97dcf5a3dd0a"}, - {file = "torch-2.1.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ca96253b761e9aaf8e06fb30a66ee301aecbf15bb5a303097de1969077620b6"}, - {file = "torch-2.1.2-cp39-cp39-manylinux2014_aarch64.whl", hash = 
"sha256:d93ba70f67b08c2ae5598ee711cbc546a1bc8102cef938904b8c85c2089a51a0"}, - {file = "torch-2.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:255b50bc0608db177e6a3cc118961d77de7e5105f07816585fa6f191f33a9ff3"}, - {file = "torch-2.1.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6984cd5057c0c977b3c9757254e989d3f1124f4ce9d07caa6cb637783c71d42a"}, - {file = "torch-2.1.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:bc195d7927feabc0eb7c110e457c955ed2ab616f3c7c28439dd4188cf589699f"}, + {file = "torch-2.2.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bc889d311a855dd2dfd164daf8cc903a6b7273a747189cebafdd89106e4ad585"}, + {file = "torch-2.2.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:15dffa4cc3261fa73d02f0ed25f5fa49ecc9e12bf1ae0a4c1e7a88bbfaad9030"}, + {file = "torch-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:11e8fe261233aeabd67696d6b993eeb0896faa175c6b41b9a6c9f0334bdad1c5"}, + {file = "torch-2.2.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:b2e2200b245bd9f263a0d41b6a2dab69c4aca635a01b30cca78064b0ef5b109e"}, + {file = "torch-2.2.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:877b3e6593b5e00b35bbe111b7057464e76a7dd186a287280d941b564b0563c2"}, + {file = "torch-2.2.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:ad4c03b786e074f46606f4151c0a1e3740268bcf29fbd2fdf6666d66341c1dcb"}, + {file = "torch-2.2.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:32827fa1fbe5da8851686256b4cd94cc7b11be962862c2293811c94eea9457bf"}, + {file = "torch-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:f9ef0a648310435511e76905f9b89612e45ef2c8b023bee294f5e6f7e73a3e7c"}, + {file = "torch-2.2.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:95b9b44f3bcebd8b6cd8d37ec802048c872d9c567ba52c894bba90863a439059"}, + {file = "torch-2.2.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:49aa4126ede714c5aeef7ae92969b4b0bbe67f19665106463c39f22e0a1860d1"}, + {file = "torch-2.2.2-cp312-cp312-manylinux1_x86_64.whl", hash = 
"sha256:cf12cdb66c9c940227ad647bc9cf5dba7e8640772ae10dfe7569a0c1e2a28aca"}, + {file = "torch-2.2.2-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:89ddac2a8c1fb6569b90890955de0c34e1724f87431cacff4c1979b5f769203c"}, + {file = "torch-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:451331406b760f4b1ab298ddd536486ab3cfb1312614cfe0532133535be60bea"}, + {file = "torch-2.2.2-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:eb4d6e9d3663e26cd27dc3ad266b34445a16b54908e74725adb241aa56987533"}, + {file = "torch-2.2.2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:bf9558da7d2bf7463390b3b2a61a6a3dbb0b45b161ee1dd5ec640bf579d479fc"}, + {file = "torch-2.2.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cd2bf7697c9e95fb5d97cc1d525486d8cf11a084c6af1345c2c2c22a6b0029d0"}, + {file = "torch-2.2.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:b421448d194496e1114d87a8b8d6506bce949544e513742b097e2ab8f7efef32"}, + {file = "torch-2.2.2-cp38-cp38-win_amd64.whl", hash = "sha256:3dbcd563a9b792161640c0cffe17e3270d85e8f4243b1f1ed19cca43d28d235b"}, + {file = "torch-2.2.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:31f4310210e7dda49f1fb52b0ec9e59382cfcb938693f6d5378f25b43d7c1d29"}, + {file = "torch-2.2.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:c795feb7e8ce2e0ef63f75f8e1ab52e7fd5e1a4d7d0c31367ade1e3de35c9e95"}, + {file = "torch-2.2.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a6e5770d68158d07456bfcb5318b173886f579fdfbf747543901ce718ea94782"}, + {file = "torch-2.2.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:67dcd726edff108e2cd6c51ff0e416fd260c869904de95750e80051358680d24"}, + {file = "torch-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:539d5ef6c4ce15bd3bd47a7b4a6e7c10d49d4d21c0baaa87c7d2ef8698632dfb"}, + {file = "torch-2.2.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:dff696de90d6f6d1e8200e9892861fd4677306d0ef604cb18f2134186f719f82"}, + {file = "torch-2.2.2-cp39-none-macosx_11_0_arm64.whl", hash = 
"sha256:3a4dd910663fd7a124c056c878a52c2b0be4a5a424188058fe97109d4436ee42"}, ] [package.dependencies] @@ -4519,15 +4689,15 @@ nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linu nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nccl-cu12 = {version = "2.18.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" -triton = {version = "2.1.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -typing-extensions = "*" +triton = {version = "2.2.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} +typing-extensions = ">=4.8.0" [package.extras] -dynamo = ["jinja2"] opt-einsum = ["opt-einsum (>=3.3)"] +optree = ["optree (>=0.9.1)"] [[package]] name = "tornado" @@ -4551,13 +4721,13 @@ files = [ [[package]] name = "tqdm" -version = "4.66.1" +version = "4.66.2" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" files = [ - {file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"}, - {file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"}, + {file = "tqdm-4.66.2-py3-none-any.whl", hash = "sha256:1ee4f8a893eb9bef51c6e35730cebf234d5d0b6bd112b0271e10ed7c24a02bd9"}, + {file = "tqdm-4.66.2.tar.gz", hash = 
"sha256:6cd52cdf0fef0e0f543299cfc96fec90d7b8a7e88745f411ec33eb44d5ed3531"}, ] [package.dependencies] @@ -4571,28 +4741,28 @@ telegram = ["requests"] [[package]] name = "traitlets" -version = "5.14.1" +version = "5.14.2" description = "Traitlets Python configuration system" optional = false python-versions = ">=3.8" files = [ - {file = "traitlets-5.14.1-py3-none-any.whl", hash = "sha256:2e5a030e6eff91737c643231bfcf04a65b0132078dad75e4936700b213652e74"}, - {file = "traitlets-5.14.1.tar.gz", hash = "sha256:8585105b371a04b8316a43d5ce29c098575c2e477850b62b848b964f1444527e"}, + {file = "traitlets-5.14.2-py3-none-any.whl", hash = "sha256:fcdf85684a772ddeba87db2f398ce00b40ff550d1528c03c14dbf6a02003cd80"}, + {file = "traitlets-5.14.2.tar.gz", hash = "sha256:8cdd83c040dab7d1dee822678e5f5d100b514f7b72b01615b26fc5718916fdf9"}, ] [package.extras] docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] -test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] +test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.1)", "pytest-mock", "pytest-mypy-testing"] [[package]] name = "transformers" -version = "4.37.2" +version = "4.39.2" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.37.2-py3-none-any.whl", hash = "sha256:595a8b12a1fcc4ad0ced49ce206c58e17be68c85d7aee3d7546d04a32c910d2e"}, - {file = "transformers-4.37.2.tar.gz", hash = "sha256:f307082ae5d528b8480611a4879a4a11651012d0e9aaea3f6cf17219ffd95542"}, + {file = "transformers-4.39.2-py3-none-any.whl", hash = "sha256:8388a4ae1d91ade935f5c5b36dc47aa1a352b092c30595e3337b49a5f7e71b4e"}, + {file = "transformers-4.39.2.tar.gz", hash = "sha256:be0c7392cb92ab48efab2656f1cfd1cbda33b2b8a2917a18bd1196707dbebe14"}, ] [package.dependencies] @@ -4609,16 +4779,16 @@ tqdm = ">=4.27" [package.extras] accelerate = ["accelerate (>=0.21.0)"] 
-agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.11,!=1.12.0)"] -all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.11,!=1.12.0)", "torchaudio", "torchvision"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] codecarbon = ["codecarbon (==1.2.0)"] deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", 
"pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.11,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", 
"pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.11,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch 
(>=1.11,!=1.12.0)", "torchaudio", "torchvision"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill 
(<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", 
"keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch", "torchaudio", "torchvision"] docs-specific = ["hf-doc-builder"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] @@ -4635,62 +4805,60 @@ ray = ["ray[tune] (>=2.7.0)"] retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] sagemaker = ["sagemaker (>=2.31.0)"] sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] -serving = ["fastapi", "pydantic (<2)", "starlette", "uvicorn"] +serving = ["fastapi", "pydantic", "starlette", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic (<2)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] tf = ["keras-nlp (>=0.3.1)", 
"onnxconverter-common", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] timm = ["timm"] tokenizers = ["tokenizers (>=0.14,<0.19)"] -torch = ["accelerate (>=0.21.0)", "torch (>=1.11,!=1.12.0)"] +torch = ["accelerate (>=0.21.0)", "torch"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.11,!=1.12.0)", "tqdm (>=4.27)"] +torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch", "tqdm (>=4.27)"] video = ["av (==9.2.0)", "decord (==0.6.0)"] vision = ["Pillow (>=10.0.1,<=15.0)"] [[package]] name = "triton" -version = "2.1.0" +version = "2.2.0" description = "A language and compiler for custom Deep Learning operations" optional = false python-versions = "*" files = [ - {file = "triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:66439923a30d5d48399b08a9eae10370f6c261a5ec864a64983bae63152d39d7"}, - {file = "triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:919b06453f0033ea52c13eaf7833de0e57db3178d23d4e04f9fc71c4f2c32bf8"}, - {file = "triton-2.1.0-0-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ae4bb8a91de790e1866405211c4d618379781188f40d5c4c399766914e84cd94"}, - {file = 
"triton-2.1.0-0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39f6fb6bdccb3e98f3152e3fbea724f1aeae7d749412bbb1fa9c441d474eba26"}, - {file = "triton-2.1.0-0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:21544e522c02005a626c8ad63d39bdff2f31d41069592919ef281e964ed26446"}, - {file = "triton-2.1.0-0-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:143582ca31dd89cd982bd3bf53666bab1c7527d41e185f9e3d8a3051ce1b663b"}, - {file = "triton-2.1.0-0-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:82fc5aeeedf6e36be4e4530cbdcba81a09d65c18e02f52dc298696d45721f3bd"}, - {file = "triton-2.1.0-0-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:81a96d110a738ff63339fc892ded095b31bd0d205e3aace262af8400d40b6fa8"}, + {file = "triton-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2294514340cfe4e8f4f9e5c66c702744c4a117d25e618bd08469d0bfed1e2e5"}, + {file = "triton-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da58a152bddb62cafa9a857dd2bc1f886dbf9f9c90a2b5da82157cd2b34392b0"}, + {file = "triton-2.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af58716e721460a61886668b205963dc4d1e4ac20508cc3f623aef0d70283d5"}, + {file = "triton-2.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8fe46d3ab94a8103e291bd44c741cc294b91d1d81c1a2888254cbf7ff846dab"}, + {file = "triton-2.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8ce26093e539d727e7cf6f6f0d932b1ab0574dc02567e684377630d86723ace"}, + {file = "triton-2.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:227cc6f357c5efcb357f3867ac2a8e7ecea2298cd4606a8ba1e931d1d5a947df"}, ] [package.dependencies] filelock = "*" [package.extras] -build = ["cmake (>=3.18)", "lit"] -tests = ["autopep8", "flake8", "isort", "numpy", 
"pytest", "scipy (>=1.7.1)"] -tutorials = ["matplotlib", "pandas", "tabulate"] +build = ["cmake (>=3.20)", "lit"] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] +tutorials = ["matplotlib", "pandas", "tabulate", "torch"] [[package]] name = "typeguard" -version = "4.1.5" +version = "4.2.1" description = "Run-time type checker for Python" optional = false python-versions = ">=3.8" files = [ - {file = "typeguard-4.1.5-py3-none-any.whl", hash = "sha256:8923e55f8873caec136c892c3bed1f676eae7be57cdb94819281b3d3bc9c0953"}, - {file = "typeguard-4.1.5.tar.gz", hash = "sha256:ea0a113bbc111bcffc90789ebb215625c963411f7096a7e9062d4e4630c155fd"}, + {file = "typeguard-4.2.1-py3-none-any.whl", hash = "sha256:7da3bd46e61f03e0852f8d251dcbdc2a336aa495d7daff01e092b55327796eb8"}, + {file = "typeguard-4.2.1.tar.gz", hash = "sha256:c556a1b95948230510070ca53fa0341fb0964611bd05d598d87fb52115d65fee"}, ] [package.dependencies] importlib-metadata = {version = ">=3.6", markers = "python_version < \"3.10\""} -typing-extensions = {version = ">=4.7.0", markers = "python_version < \"3.12\""} +typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""} [package.extras] doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)"] @@ -4698,45 +4866,42 @@ test = ["coverage[toml] (>=7)", "mypy (>=1.2.0)", "pytest (>=7)"] [[package]] name = "typer" -version = "0.9.0" +version = "0.11.0" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." 
optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" files = [ - {file = "typer-0.9.0-py3-none-any.whl", hash = "sha256:5d96d986a21493606a358cae4461bd8cdf83cbf33a5aa950ae629ca3b51467ee"}, - {file = "typer-0.9.0.tar.gz", hash = "sha256:50922fd79aea2f4751a8e0408ff10d2662bd0c8bbfa84755a699f3bada2978b2"}, + {file = "typer-0.11.0-py3-none-any.whl", hash = "sha256:049cc47bef39f46b043eddd9165492209fdd9bc7d79afa7ba9cc5cd017caa817"}, + {file = "typer-0.11.0.tar.gz", hash = "sha256:a6ce173c0f03d3a41b49c0a945874cc489e91f88faabf76517b2b91c670fcde7"}, ] [package.dependencies] -click = ">=7.1.1,<9.0.0" +click = ">=8.0.0" typing-extensions = ">=3.7.4.3" [package.extras] all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] -dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] -doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"] -test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] [[package]] name = "types-python-dateutil" -version = "2.8.19.20240106" +version = "2.9.0.20240316" description = "Typing stubs for python-dateutil" optional = false python-versions = ">=3.8" files = [ - {file = "types-python-dateutil-2.8.19.20240106.tar.gz", hash = "sha256:1f8db221c3b98e6ca02ea83a58371b22c374f42ae5bbdf186db9c9a76581459f"}, - {file = "types_python_dateutil-2.8.19.20240106-py3-none-any.whl", hash = "sha256:efbbdc54590d0f16152fa103c9879c7d4a00e82078f6e2cf01769042165acaa2"}, + {file = "types-python-dateutil-2.9.0.20240316.tar.gz", hash = "sha256:5d2f2e240b86905e40944dd787db6da9263f0deabef1076ddaed797351ec0202"}, + {file = 
"types_python_dateutil-2.9.0.20240316-py3-none-any.whl", hash = "sha256:6b8cb66d960771ce5ff974e9dd45e38facb81718cc1e208b10b1baccbfdbee3b"}, ] [[package]] name = "typing-extensions" -version = "4.9.0" +version = "4.10.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.9.0-py3-none-any.whl", hash = "sha256:af72aea155e91adfc61c3ae9e0e342dbc0cba726d6cba4b6c72c1f34e47291cd"}, - {file = "typing_extensions-4.9.0.tar.gz", hash = "sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783"}, + {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"}, + {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, ] [[package]] @@ -4756,13 +4921,13 @@ typing-extensions = ">=3.7.4" [[package]] name = "tzdata" -version = "2023.4" +version = "2024.1" description = "Provider of IANA time zone data" optional = false python-versions = ">=2" files = [ - {file = "tzdata-2023.4-py2.py3-none-any.whl", hash = "sha256:aa3ace4329eeacda5b7beb7ea08ece826c28d761cda36e747cfbf97996d39bf3"}, - {file = "tzdata-2023.4.tar.gz", hash = "sha256:dd54c94f294765522c77399649b4fefd95522479a664a0cec87f41bebc6148c9"}, + {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, + {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, ] [[package]] @@ -4781,29 +4946,30 @@ dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake [[package]] name = "urllib3" -version = "2.1.0" +version = "2.2.1" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.1.0-py3-none-any.whl", hash = "sha256:55901e917a5896a349ff771be919f8bd99aff50b79fe58fec595eb37bbc56bb3"}, - {file = "urllib3-2.1.0.tar.gz", hash = "sha256:df7aa8afb0148fa78488e7899b2c59b5f4ffcfa82e6c54ccb9dd37c1d7b52d54"}, + {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, + {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, ] [package.extras] brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] [[package]] name = "wandb" -version = "0.16.2" +version = "0.16.5" description = "A CLI and library for interacting with the Weights & Biases API." optional = false python-versions = ">=3.7" files = [ - {file = "wandb-0.16.2-py3-none-any.whl", hash = "sha256:6b119cf3c01f35e7276b62d052128e5320621d182c9eb5796a12cf62a9b3134f"}, - {file = "wandb-0.16.2.tar.gz", hash = "sha256:e40cd79ea6272fe4762a80b9f47b172e141daeb3b56eb9d1e192ebd10752e64e"}, + {file = "wandb-0.16.5-py3-none-any.whl", hash = "sha256:023b6c72a6ef13085c9a970f6714548eca64f56d3d8698e42372764950dfd004"}, + {file = "wandb-0.16.5.tar.gz", hash = "sha256:c317d55af93a688f3eafcdfec897f7b72da1fe1525140e076ecdaab8b09aa46e"}, ] [package.dependencies] @@ -4829,11 +4995,13 @@ async = ["httpx (>=0.23.0)"] aws = ["boto3"] azure = ["azure-identity", "azure-storage-blob"] gcp = ["google-cloud-storage"] +importers = ["filelock", "mlflow", "polars", "rich", "tenacity"] kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"] -launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", 
"kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "typing-extensions"] +launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "tomli", "typing-extensions"] media = ["bokeh", "moviepy", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit-pypi", "soundfile"] models = ["cloudpickle"] perf = ["orjson"] +reports = ["pydantic (>=2.0.0)"] sweeps = ["sweeps (>=0.2.0)"] [[package]] @@ -4891,13 +5059,13 @@ test = ["websockets"] [[package]] name = "widgetsnbextension" -version = "4.0.9" +version = "4.0.10" description = "Jupyter interactive widgets for Jupyter Notebook" optional = false python-versions = ">=3.7" files = [ - {file = "widgetsnbextension-4.0.9-py3-none-any.whl", hash = "sha256:91452ca8445beb805792f206e560c1769284267a30ceb1cec9f5bcc887d15175"}, - {file = "widgetsnbextension-4.0.9.tar.gz", hash = "sha256:3c1f5e46dc1166dfd40a42d685e6a51396fd34ff878742a3e47c6f0cc4a2a385"}, + {file = "widgetsnbextension-4.0.10-py3-none-any.whl", hash = "sha256:d37c3724ec32d8c48400a435ecfa7d3e259995201fbefa37163124a9fcb393cc"}, + {file = "widgetsnbextension-4.0.10.tar.gz", hash = "sha256:64196c5ff3b9a9183a8e699a4227fb0b7002f252c814098e66c4d1cd0644688f"}, ] [[package]] @@ -5122,20 +5290,20 @@ multidict = ">=4.0" [[package]] name = "zipp" -version = "3.17.0" +version = "3.18.1" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"}, - {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"}, + {file = 
"zipp-3.18.1-py3-none-any.whl", hash = "sha256:206f5a15f2af3dbaee80769fb7dc6f249695e940acca08dfb2a4769fe61e538b"}, + {file = "zipp-3.18.1.tar.gz", hash = "sha256:2884ed22e7d8961de1c9a05142eb69a247f120291bc0206a00a7642f09b5b715"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "1ef3e46351ab989160cd31387ea5dcc887ba643de8a6f5329ffe0dbbbf16fdc7" +content-hash = "08a474ca3da4e9c666274da63409e0912e777c9b925cc06b4872f04bbc6a868c" diff --git a/pyproject.toml b/pyproject.toml index dd74eddc5..84b46ea4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ typing-extensions="*" wandb=">=0.13.5" better-abc = "^0.0.3" + sentencepiece = "*" [tool.poetry.group] [tool.poetry.group.dev.dependencies] diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index fa408bc96..72affb15f 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -171,6 +171,10 @@ "google/gemma-7b", "google/gemma-2b-it", "google/gemma-7b-it", + "01-ai/Yi-6B", + "01-ai/Yi-34B", + "01-ai/Yi-6B-Chat", + "01-ai/Yi-34B-Chat", ] """Official model names for models on HuggingFace.""" @@ -575,6 +579,10 @@ 
"google/gemma-7b": ["gemma-7b"], "google/gemma-2b-it": ["gemma-2b-it"], "google/gemma-7b-it": ["gemma-7b-it"], + "01-ai/Yi-6B": ["yi-6b", "Yi-6B"], + "01-ai/Yi-34B": ["yi-34b", "Yi-34B"], + "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"], + "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"], } """Model aliases for models on HuggingFace.""" @@ -924,6 +932,30 @@ def convert_hf_model_config(model_name: str, **kwargs): "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, "normalization_type": "LN", } + elif architecture == "LlamaForCausalLM": + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": hf_config.max_position_embeddings, + "eps": hf_config.rms_norm_eps, + "d_vocab": hf_config.vocab_size, + "act_fn": hf_config.hidden_act, + "n_key_value_heads": hf_config.num_key_value_heads + if hf_config.num_key_value_heads != hf_config.num_attention_heads + else None, + # This is done because the current implementation of GQA will use Grouped-Query Attention if + # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as + # the same as hf_config.num_attention_heads, in which case GQA should not be used. 
+ "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "rotary_adjacent_pairs": False, + "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, + "final_rms": True, + "gated_mlp": True, + } elif architecture == "QWenLMHeadModel": cfg_dict = { "d_model": hf_config.hidden_size, @@ -1662,6 +1694,8 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): state_dict["embed.W_E"] = llama.model.embed_tokens.weight + # Some models with the Llama architecture use Grouped Query Attention, and so for these we need to modify + # the state dict keys for the K/V attention weight/biases, prepending "_" to the key names. using_gqa = cfg.n_key_value_heads is not None gqa_uscore = "_" if using_gqa else "" From de9a70b0b2b23e46f493f6a0ac669e4424a208af Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 3 Apr 2024 02:23:16 +0200 Subject: [PATCH 47/73] updated docs to account for additional test suites (#533) --- docs/source/content/contributing.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/content/contributing.md b/docs/source/content/contributing.md index f90c0be53..544353aa1 100644 --- a/docs/source/content/contributing.md +++ b/docs/source/content/contributing.md @@ -32,6 +32,8 @@ quite slow (as we only have CPU actions) so the smaller models like `attn-only-1 - Unit tests only via `make unit-test` - Acceptance tests only via `make acceptance-test` - Docstring tests only via `make docstring-test` +- Notebook tests only via `make notebook-test` +- Run all test suites mentioned `make test` ## Formatting From bae79771240aa9d5c38a76df99a2f554ef798bc1 Mon Sep 17 00:00:00 2001 From: Toni Kukurin Date: Thu, 4 Apr 2024 01:08:51 +0200 Subject: [PATCH 48/73] bugfix subscripted generics (#534) --- transformer_lens/ActivationCache.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index cf25f4eeb..83b5b4110 100644 --- 
a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -10,6 +10,7 @@ class first, including the examples, and then skimming the available methods. You can then refer back to these docs depending on what you need to do. """ + from __future__ import annotations import logging @@ -830,10 +831,8 @@ def get_neuron_results( Tensor of the results. """ if type(neuron_slice) is not Slice: - assert isinstance(neuron_slice, SliceInput) neuron_slice = Slice(neuron_slice) if type(pos_slice) is not Slice: - assert isinstance(pos_slice, SliceInput) pos_slice = Slice(pos_slice) neuron_acts = self[("post", layer, "mlp")] From f052f3955a478ddffdcbbb5d1187aa1c682aef68 Mon Sep 17 00:00:00 2001 From: Pavan Katta Date: Sat, 6 Apr 2024 03:41:52 +0530 Subject: [PATCH 49/73] Fix platform markers (#510) Co-authored-by: Bryce Meyer --- poetry.lock | 233 ++++++++++++++++++++++++++----------------------- pyproject.toml | 5 +- 2 files changed, 130 insertions(+), 108 deletions(-) diff --git a/poetry.lock b/poetry.lock index c5e5c03a3..0d13a6dd7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,13 +2,13 @@ [[package]] name = "accelerate" -version = "0.28.0" +version = "0.29.1" description = "Accelerate" optional = false python-versions = ">=3.8.0" files = [ - {file = "accelerate-0.28.0-py3-none-any.whl", hash = "sha256:8ae25f8a8dc4cf12283842c469113836300545fb0dfa46fef331fb0a2ac8b421"}, - {file = "accelerate-0.28.0.tar.gz", hash = "sha256:32019a49f4b3a85cc179ac4e38e9e2971f1a997dee026be0512816499464c4d5"}, + {file = "accelerate-0.29.1-py3-none-any.whl", hash = "sha256:7eda0c8bc62bc59129103310f1272a0fb7b3ebc55fc8920cfe1c102db30aca58"}, + {file = "accelerate-0.29.1.tar.gz", hash = "sha256:d1d0e5a591177891812fd6d1bc843af191e1192c80e5180258f52fefcb653a9f"}, ] [package.dependencies] @@ -21,14 +21,14 @@ safetensors = ">=0.3.1" torch = ">=1.10.0" [package.extras] -dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "deepspeed (<0.13.0)", "evaluate", "hf-doc-builder 
(>=0.3.0)", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.2.1,<0.3.0)", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "deepspeed", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.2.1,<0.3.0)", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.2.1,<0.3.0)"] rich = ["rich"] sagemaker = ["sagemaker"] -test-dev = ["bitsandbytes", "datasets", "deepspeed (<0.13.0)", "evaluate", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +test-dev = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] test-prod = ["parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist"] test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"] -testing = ["bitsandbytes", "datasets", "deepspeed (<0.13.0)", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] +testing = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] [[package]] name = "aiohttp" @@ -651,22 +651,35 @@ files = [ [[package]] name = "circuitsvis" -version = "1.41.0" +version = "1.43.2" description = "Mechanistic Interpretability Visualizations" optional = false -python-versions = ">=3.7,<4.0" +python-versions = ">=3.8" files = [ - {file = "circuitsvis-1.41.0-py3-none-any.whl", hash = "sha256:53dc12c955c160b8108a0eb17ed14a34ba9f53b218457d29f351cba3db31acb7"}, - {file = 
"circuitsvis-1.41.0.tar.gz", hash = "sha256:386385f38d8b9de1bbef125fa282afc9157027bc2dcdc4c04feafbc22bc71d17"}, + {file = "circuitsvis-1.43.2-py3-none-any.whl", hash = "sha256:1128fde5de8b738dd3c932d0b0ec4ee5556387b4405592fdf37f617e647183fb"}, + {file = "circuitsvis-1.43.2.tar.gz", hash = "sha256:388c1a6ea1bcf308da51fa6f67be761483ba361321d2e111f4c28faaea458287"}, ] [package.dependencies] -importlib-metadata = ">=5.1.0,<6.0.0" +importlib-metadata = ">=5.1.0" numpy = [ - {version = ">=1.21,<2.0", markers = "python_version < \"3.10\""}, - {version = ">=1.23,<2.0", markers = "python_version >= \"3.10\""}, + {version = ">=1.20,<1.25", markers = "python_version >= \"3.8\" and python_version < \"3.9\""}, + {version = ">=1.24", markers = "python_version >= \"3.9\" and python_version < \"3.12\""}, + {version = ">=1.26", markers = "python_version >= \"3.12\" and python_version < \"3.13\""}, ] -torch = {version = ">=1.10", markers = "python_version >= \"3.8\""} +nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = 
"12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.18.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +torch = ">=1.10" +triton = {version = "2.1.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} [[package]] name = "click" @@ -1162,20 +1175,21 @@ smmap = ">=3.0.1,<6" [[package]] name = "gitpython" -version = "3.1.42" +version = "3.1.43" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" files = [ - {file = "GitPython-3.1.42-py3-none-any.whl", hash = "sha256:1bf9cd7c9e7255f77778ea54359e54ac22a72a5b51288c457c881057b7bb9ecd"}, - {file = "GitPython-3.1.42.tar.gz", hash = "sha256:2d99869e0fef71a73cbd242528105af1d6c1b108c60dfabd994bf292f76c3ceb"}, + {file = "GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff"}, + {file = "GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c"}, ] [package.dependencies] gitdb = ">=4.0.1,<5" [package.extras] -test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar"] +doc = ["sphinx (==4.3.2)", "sphinx-autodoc-typehints", "sphinx-rtd-theme", "sphinxcontrib-applehelp (>=1.0.2,<=1.0.4)", "sphinxcontrib-devhelp (==1.0.2)", "sphinxcontrib-htmlhelp (>=2.0.0,<=2.0.1)", "sphinxcontrib-qthelp (==1.0.3)", "sphinxcontrib-serializinghtml (==1.1.5)"] +test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"] [[package]] name = "h11" @@ -1235,13 +1249,13 @@ socks = 
["socksio (==1.*)"] [[package]] name = "huggingface-hub" -version = "0.22.1" +version = "0.22.2" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.22.1-py3-none-any.whl", hash = "sha256:eac63947923d15c9a68681d7ed2d9599e058860617064e3ee6bd91a4b954faaf"}, - {file = "huggingface_hub-0.22.1.tar.gz", hash = "sha256:5b8aaee5f3618cd432f49886da9935bbe8fab92d719011826430907b93171dd8"}, + {file = "huggingface_hub-0.22.2-py3-none-any.whl", hash = "sha256:3429e25f38ccb834d310804a3b711e7e4953db5a9e420cc147a5e194ca90fd17"}, + {file = "huggingface_hub-0.22.2.tar.gz", hash = "sha256:32e9a9a6843c92f253ff9ca16b9985def4d80a93fb357af5353f770ef74a81be"}, ] [package.dependencies] @@ -1291,22 +1305,22 @@ files = [ [[package]] name = "importlib-metadata" -version = "5.2.0" +version = "7.1.0" description = "Read metadata from Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "importlib_metadata-5.2.0-py3-none-any.whl", hash = "sha256:0eafa39ba42bf225fc00e67f701d71f85aead9f878569caf13c3724f704b970f"}, - {file = "importlib_metadata-5.2.0.tar.gz", hash = "sha256:404d48d62bba0b7a77ff9d405efd91501bef2e67ff4ace0bed40a0cf28c3c7cd"}, + {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, + {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, ] [package.dependencies] zipp = ">=0.5" [package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", 
"packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] [[package]] name = "importlib-resources" @@ -2341,19 +2355,19 @@ webpdf = ["playwright"] [[package]] name = "nbformat" -version = "5.10.3" +version = "5.10.4" description = "The Jupyter Notebook format" optional = false python-versions = ">=3.8" files = [ - {file = "nbformat-5.10.3-py3-none-any.whl", hash = "sha256:d9476ca28676799af85385f409b49d95e199951477a159a576ef2a675151e5e8"}, - {file = "nbformat-5.10.3.tar.gz", hash = "sha256:60ed5e910ef7c6264b87d644f276b1b49e24011930deef54605188ddeb211685"}, + {file = "nbformat-5.10.4-py3-none-any.whl", hash = "sha256:3b48d6c8fbca4b299bf3982ea7db1af21580e4fec269ad087b9e81588891200b"}, + {file = "nbformat-5.10.4.tar.gz", hash = "sha256:322168b14f937a5d11362988ecac2a4952d3d8e3a2cbeb2319584631226d5b3a"}, ] [package.dependencies] -fastjsonschema = "*" +fastjsonschema = ">=2.15" jsonschema = ">=2.6" -jupyter-core = "*" +jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" traitlets = ">=5.1" [package.extras] @@ -2659,24 +2673,23 @@ nvidia-nvjitlink-cu12 = "*" [[package]] name = "nvidia-nccl-cu12" -version = "2.19.3" +version = "2.18.1" description = "NVIDIA Collective Communication Library (NCCL) Runtime" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl", hash = "sha256:a9734707a2c96443331c1e48c717024aa6678a0e2a4cb66b2c364d18cee6b48d"}, + {file = "nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:1a6c4acefcbebfa6de320f412bf7866de856e786e0462326ba1bac40de0b5e71"}, ] [[package]] name = 
"nvidia-nvjitlink-cu12" -version = "12.4.99" +version = "12.4.127" description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_aarch64.whl", hash = "sha256:75d6498c96d9adb9435f2bbdbddb479805ddfb97b5c1b32395c694185c20ca57"}, - {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"}, - {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"}, + {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"}, + {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"}, ] [[package]] @@ -2806,18 +2819,18 @@ files = [ [[package]] name = "parso" -version = "0.8.3" +version = "0.8.4" description = "A Python Parser" optional = false python-versions = ">=3.6" files = [ - {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"}, - {file = "parso-0.8.3.tar.gz", hash = "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"}, + {file = "parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18"}, + {file = "parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d"}, ] [package.extras] -qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] -testing = ["docopt", "pytest (<6.0.0)"] +qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] +testing = ["docopt", "pytest"] [[package]] name = "pathspec" @@ -3135,13 +3148,13 @@ typer = ">=0.4.1" [[package]] name = "pycparser" -version = "2.21" +version = "2.22" description = "C 
parser in Python" optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +python-versions = ">=3.8" files = [ - {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, - {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, + {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, + {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, ] [[package]] @@ -3317,7 +3330,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4001,13 +4013,13 @@ files = [ [[package]] name = "sentry-sdk" -version = "1.44.0" +version = "1.44.1" description = "Python client for Sentry (https://sentry.io)" optional = false python-versions = "*" files = [ - {file = "sentry-sdk-1.44.0.tar.gz", hash = 
"sha256:f7125a9235795811962d52ff796dc032cd1d0dd98b59beaced8380371cd9c13c"}, - {file = "sentry_sdk-1.44.0-py2.py3-none-any.whl", hash = "sha256:eb65289da013ca92fad2694851ad2f086aa3825e808dc285bd7dcaf63602bb18"}, + {file = "sentry-sdk-1.44.1.tar.gz", hash = "sha256:24e6a53eeabffd2f95d952aa35ca52f0f4201d17f820ac9d3ff7244c665aaf68"}, + {file = "sentry_sdk-1.44.1-py2.py3-none-any.whl", hash = "sha256:5f75eb91d8ab6037c754a87b8501cc581b2827e923682f593bed3539ce5b3999"}, ] [package.dependencies] @@ -4162,6 +4174,17 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +[[package]] +name = "shellingham" +version = "1.5.4" +description = "Tool to Detect Surrounding Shell" +optional = false +python-versions = ">=3.7" +files = [ + {file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"}, + {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"}, +] + [[package]] name = "six" version = "1.16.0" @@ -4643,36 +4666,31 @@ files = [ [[package]] name = "torch" -version = "2.2.2" +version = "2.1.2" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional 
= false python-versions = ">=3.8.0" files = [ - {file = "torch-2.2.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bc889d311a855dd2dfd164daf8cc903a6b7273a747189cebafdd89106e4ad585"}, - {file = "torch-2.2.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:15dffa4cc3261fa73d02f0ed25f5fa49ecc9e12bf1ae0a4c1e7a88bbfaad9030"}, - {file = "torch-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:11e8fe261233aeabd67696d6b993eeb0896faa175c6b41b9a6c9f0334bdad1c5"}, - {file = "torch-2.2.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:b2e2200b245bd9f263a0d41b6a2dab69c4aca635a01b30cca78064b0ef5b109e"}, - {file = "torch-2.2.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:877b3e6593b5e00b35bbe111b7057464e76a7dd186a287280d941b564b0563c2"}, - {file = "torch-2.2.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:ad4c03b786e074f46606f4151c0a1e3740268bcf29fbd2fdf6666d66341c1dcb"}, - {file = "torch-2.2.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:32827fa1fbe5da8851686256b4cd94cc7b11be962862c2293811c94eea9457bf"}, - {file = "torch-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:f9ef0a648310435511e76905f9b89612e45ef2c8b023bee294f5e6f7e73a3e7c"}, - {file = "torch-2.2.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:95b9b44f3bcebd8b6cd8d37ec802048c872d9c567ba52c894bba90863a439059"}, - {file = "torch-2.2.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:49aa4126ede714c5aeef7ae92969b4b0bbe67f19665106463c39f22e0a1860d1"}, - {file = "torch-2.2.2-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:cf12cdb66c9c940227ad647bc9cf5dba7e8640772ae10dfe7569a0c1e2a28aca"}, - {file = "torch-2.2.2-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:89ddac2a8c1fb6569b90890955de0c34e1724f87431cacff4c1979b5f769203c"}, - {file = "torch-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:451331406b760f4b1ab298ddd536486ab3cfb1312614cfe0532133535be60bea"}, - {file = "torch-2.2.2-cp312-none-macosx_10_9_x86_64.whl", hash = 
"sha256:eb4d6e9d3663e26cd27dc3ad266b34445a16b54908e74725adb241aa56987533"}, - {file = "torch-2.2.2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:bf9558da7d2bf7463390b3b2a61a6a3dbb0b45b161ee1dd5ec640bf579d479fc"}, - {file = "torch-2.2.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cd2bf7697c9e95fb5d97cc1d525486d8cf11a084c6af1345c2c2c22a6b0029d0"}, - {file = "torch-2.2.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:b421448d194496e1114d87a8b8d6506bce949544e513742b097e2ab8f7efef32"}, - {file = "torch-2.2.2-cp38-cp38-win_amd64.whl", hash = "sha256:3dbcd563a9b792161640c0cffe17e3270d85e8f4243b1f1ed19cca43d28d235b"}, - {file = "torch-2.2.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:31f4310210e7dda49f1fb52b0ec9e59382cfcb938693f6d5378f25b43d7c1d29"}, - {file = "torch-2.2.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:c795feb7e8ce2e0ef63f75f8e1ab52e7fd5e1a4d7d0c31367ade1e3de35c9e95"}, - {file = "torch-2.2.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a6e5770d68158d07456bfcb5318b173886f579fdfbf747543901ce718ea94782"}, - {file = "torch-2.2.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:67dcd726edff108e2cd6c51ff0e416fd260c869904de95750e80051358680d24"}, - {file = "torch-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:539d5ef6c4ce15bd3bd47a7b4a6e7c10d49d4d21c0baaa87c7d2ef8698632dfb"}, - {file = "torch-2.2.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:dff696de90d6f6d1e8200e9892861fd4677306d0ef604cb18f2134186f719f82"}, - {file = "torch-2.2.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:3a4dd910663fd7a124c056c878a52c2b0be4a5a424188058fe97109d4436ee42"}, + {file = "torch-2.1.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:3a871edd6c02dae77ad810335c0833391c1a4ce49af21ea8cf0f6a5d2096eea8"}, + {file = "torch-2.1.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:bef6996c27d8f6e92ea4e13a772d89611da0e103b48790de78131e308cf73076"}, + {file = "torch-2.1.2-cp310-cp310-win_amd64.whl", hash = 
"sha256:0e13034fd5fb323cbbc29e56d0637a3791e50dd589616f40c79adfa36a5a35a1"}, + {file = "torch-2.1.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:d9b535cad0df3d13997dbe8bd68ac33e0e3ae5377639c9881948e40794a61403"}, + {file = "torch-2.1.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:f9a55d55af02826ebfbadf4e9b682f0f27766bc33df8236b48d28d705587868f"}, + {file = "torch-2.1.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:a6ebbe517097ef289cc7952783588c72de071d4b15ce0f8b285093f0916b1162"}, + {file = "torch-2.1.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:8f32ce591616a30304f37a7d5ea80b69ca9e1b94bba7f308184bf616fdaea155"}, + {file = "torch-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:e0ee6cf90c8970e05760f898d58f9ac65821c37ffe8b04269ec787aa70962b69"}, + {file = "torch-2.1.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:76d37967c31c99548ad2c4d3f2cf191db48476f2e69b35a0937137116da356a1"}, + {file = "torch-2.1.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:e2d83f07b4aac983453ea5bf8f9aa9dacf2278a8d31247f5d9037f37befc60e4"}, + {file = "torch-2.1.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f41fe0c7ecbf903a568c73486139a75cfab287a0f6c17ed0698fdea7a1e8641d"}, + {file = "torch-2.1.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:e3225f47d50bb66f756fe9196a768055d1c26b02154eb1f770ce47a2578d3aa7"}, + {file = "torch-2.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:33d59cd03cb60106857f6c26b36457793637512998666ee3ce17311f217afe2b"}, + {file = "torch-2.1.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:8e221deccd0def6c2badff6be403e0c53491805ed9915e2c029adbcdb87ab6b5"}, + {file = "torch-2.1.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:05b18594f60a911a0c4f023f38a8bda77131fba5fd741bda626e97dcf5a3dd0a"}, + {file = "torch-2.1.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9ca96253b761e9aaf8e06fb30a66ee301aecbf15bb5a303097de1969077620b6"}, + {file = "torch-2.1.2-cp39-cp39-manylinux2014_aarch64.whl", hash = 
"sha256:d93ba70f67b08c2ae5598ee711cbc546a1bc8102cef938904b8c85c2089a51a0"}, + {file = "torch-2.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:255b50bc0608db177e6a3cc118961d77de7e5105f07816585fa6f191f33a9ff3"}, + {file = "torch-2.1.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6984cd5057c0c977b3c9757254e989d3f1124f4ce9d07caa6cb637783c71d42a"}, + {file = "torch-2.1.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:bc195d7927feabc0eb7c110e457c955ed2ab616f3c7c28439dd4188cf589699f"}, ] [package.dependencies] @@ -4689,15 +4707,15 @@ nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linu nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nccl-cu12 = {version = "2.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.18.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" -triton = {version = "2.2.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} -typing-extensions = ">=4.8.0" +triton = {version = "2.1.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +typing-extensions = "*" [package.extras] +dynamo = ["jinja2"] opt-einsum = ["opt-einsum (>=3.3)"] -optree = ["optree (>=0.9.1)"] [[package]] name = "tornado" @@ -4756,13 +4774,13 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0, [[package]] name = "transformers" -version = "4.39.2" +version = "4.39.3" description = 
"State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.39.2-py3-none-any.whl", hash = "sha256:8388a4ae1d91ade935f5c5b36dc47aa1a352b092c30595e3337b49a5f7e71b4e"}, - {file = "transformers-4.39.2.tar.gz", hash = "sha256:be0c7392cb92ab48efab2656f1cfd1cbda33b2b8a2917a18bd1196707dbebe14"}, + {file = "transformers-4.39.3-py3-none-any.whl", hash = "sha256:7838034a12cca3168247f9d2d1dba6724c9de3ae0f73a108258c6b8fc5912601"}, + {file = "transformers-4.39.3.tar.gz", hash = "sha256:2586e5ff4150f122716fc40f5530e92871befc051848fbe82600969c535b762d"}, ] [package.dependencies] @@ -4824,26 +4842,28 @@ vision = ["Pillow (>=10.0.1,<=15.0)"] [[package]] name = "triton" -version = "2.2.0" +version = "2.1.0" description = "A language and compiler for custom Deep Learning operations" optional = false python-versions = "*" files = [ - {file = "triton-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2294514340cfe4e8f4f9e5c66c702744c4a117d25e618bd08469d0bfed1e2e5"}, - {file = "triton-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da58a152bddb62cafa9a857dd2bc1f886dbf9f9c90a2b5da82157cd2b34392b0"}, - {file = "triton-2.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af58716e721460a61886668b205963dc4d1e4ac20508cc3f623aef0d70283d5"}, - {file = "triton-2.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8fe46d3ab94a8103e291bd44c741cc294b91d1d81c1a2888254cbf7ff846dab"}, - {file = "triton-2.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8ce26093e539d727e7cf6f6f0d932b1ab0574dc02567e684377630d86723ace"}, - {file = "triton-2.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:227cc6f357c5efcb357f3867ac2a8e7ecea2298cd4606a8ba1e931d1d5a947df"}, + {file = 
"triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:66439923a30d5d48399b08a9eae10370f6c261a5ec864a64983bae63152d39d7"}, + {file = "triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:919b06453f0033ea52c13eaf7833de0e57db3178d23d4e04f9fc71c4f2c32bf8"}, + {file = "triton-2.1.0-0-cp37-cp37m-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ae4bb8a91de790e1866405211c4d618379781188f40d5c4c399766914e84cd94"}, + {file = "triton-2.1.0-0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39f6fb6bdccb3e98f3152e3fbea724f1aeae7d749412bbb1fa9c441d474eba26"}, + {file = "triton-2.1.0-0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:21544e522c02005a626c8ad63d39bdff2f31d41069592919ef281e964ed26446"}, + {file = "triton-2.1.0-0-pp37-pypy37_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:143582ca31dd89cd982bd3bf53666bab1c7527d41e185f9e3d8a3051ce1b663b"}, + {file = "triton-2.1.0-0-pp38-pypy38_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:82fc5aeeedf6e36be4e4530cbdcba81a09d65c18e02f52dc298696d45721f3bd"}, + {file = "triton-2.1.0-0-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:81a96d110a738ff63339fc892ded095b31bd0d205e3aace262af8400d40b6fa8"}, ] [package.dependencies] filelock = "*" [package.extras] -build = ["cmake (>=3.20)", "lit"] -tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] -tutorials = ["matplotlib", "pandas", "tabulate", "torch"] +build = ["cmake (>=3.18)", "lit"] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] [[package]] name = "typeguard" @@ -4866,22 +4886,21 @@ test = ["coverage[toml] (>=7)", "mypy (>=1.2.0)", "pytest (>=7)"] [[package]] name = "typer" -version = "0.11.0" +version = "0.12.1" description = "Typer, build great CLIs. 
Easy to code. Based on Python type hints." optional = false python-versions = ">=3.7" files = [ - {file = "typer-0.11.0-py3-none-any.whl", hash = "sha256:049cc47bef39f46b043eddd9165492209fdd9bc7d79afa7ba9cc5cd017caa817"}, - {file = "typer-0.11.0.tar.gz", hash = "sha256:a6ce173c0f03d3a41b49c0a945874cc489e91f88faabf76517b2b91c670fcde7"}, + {file = "typer-0.12.1-py3-none-any.whl", hash = "sha256:43ebb23c8a358c3d623e31064359a65f50229d0bf73ae8dfd203f49d9126ae06"}, + {file = "typer-0.12.1.tar.gz", hash = "sha256:72d218ef3c686aed9c6ff3ca25b238aee0474a1628b29c559b18b634cfdeca88"}, ] [package.dependencies] click = ">=8.0.0" +rich = ">=10.11.0" +shellingham = ">=1.3.0" typing-extensions = ">=3.7.4.3" -[package.extras] -all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] - [[package]] name = "types-python-dateutil" version = "2.9.0.20240316" @@ -4895,13 +4914,13 @@ files = [ [[package]] name = "typing-extensions" -version = "4.10.0" +version = "4.11.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"}, - {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, + {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"}, + {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"}, ] [[package]] @@ -4963,13 +4982,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "wandb" -version = "0.16.5" +version = "0.16.6" description = "A CLI and library for interacting with the Weights & Biases API." 
optional = false python-versions = ">=3.7" files = [ - {file = "wandb-0.16.5-py3-none-any.whl", hash = "sha256:023b6c72a6ef13085c9a970f6714548eca64f56d3d8698e42372764950dfd004"}, - {file = "wandb-0.16.5.tar.gz", hash = "sha256:c317d55af93a688f3eafcdfec897f7b72da1fe1525140e076ecdaab8b09aa46e"}, + {file = "wandb-0.16.6-py3-none-any.whl", hash = "sha256:5810019a3b981c796e98ea58557a7c380f18834e0c6bdaed15df115522e5616e"}, + {file = "wandb-0.16.6.tar.gz", hash = "sha256:86f491e3012d715e0d7d7421a4d6de41abef643b7403046261f962f3e512fe1c"}, ] [package.dependencies] @@ -5306,4 +5325,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "08a474ca3da4e9c666274da63409e0912e777c9b925cc06b4872f04bbc6a868c" +content-hash = "8128709b406f2fd78fa6fc77f68878861e7d95fc62acf1e0fc80f46b1026ba56" diff --git a/pyproject.toml b/pyproject.toml index 84b46ea4a..0efa3822f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,10 @@ pandas=">=1.1.5" python=">=3.8,<4.0" rich=">=12.6.0" - torch=">=1.10,!=2.0,!=2.1.0" # Pin >=2.1.1 due to known MPS errors on 2.1.0 + torch = [ + {platform = "linux", version = ">=1.10"}, # We can use any torch version on Linux (e.g colab) + {platform = "!=linux", version = ">=1.10,!=2.0,!=2.1.0"}, # Pin >=2.1.1 on Apple devices due to known MPS errors on 2.1.0 + ] tqdm=">=4.64.1" transformers=">=4.37.2" typing-extensions="*" From 4ca06e7a679bfbfd133a0cd75f0c84e1c096d7c2 Mon Sep 17 00:00:00 2001 From: Lawrence Chan Date: Mon, 8 Apr 2024 03:21:05 -0700 Subject: [PATCH 50/73] Add Xavier and Kaiming Initializations (#537) * make cspell not mad * add new init methods Add in kaiming, xavier, and (incomplete) MuP initializations * Various small typo, comments, and bugfixes * tests for inits * more cspell edits so it's happy * run black with default -l 88 * fix to make docs compile properly * accidently is not a word, whoops --- .vscode/cspell.json | 36 +++- 
.vscode/settings.json | 4 +- tests/unit/test_tokenization_methods.py | 2 +- tests/unit/test_utils.py | 180 ++++++++++++++++++++ transformer_lens/HookedEncoder.py | 7 + transformer_lens/HookedTransformer.py | 126 ++++++++++++-- transformer_lens/HookedTransformerConfig.py | 24 ++- transformer_lens/components.py | 2 +- transformer_lens/utils.py | 88 ++++++++-- 9 files changed, 437 insertions(+), 32 deletions(-) diff --git a/.vscode/cspell.json b/.vscode/cspell.json index ba26db277..19eedf858 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -1,12 +1,13 @@ { "language": "en,en-GB", "words": [ - "adrià", "accum", + "adrià", "aengus", "alonso", "arange", "argmax", + "argmaxy", "autodiff", "autoregressive", "barez", @@ -15,16 +16,22 @@ "bertsimas", "biderman", "bilal", + "bincount", "caxis", "checkpointed", "chughtai", "circuitsvis", + "Codeparrot", "codespaces", "colab", "collectstart", "colour", "conmy", "cooney", + "crfm", + "cumsum", + "datapoint", + "dictmodel", "dimitris", "disconfirm", "dmitrii", @@ -32,11 +39,14 @@ "doctest", "doctree", "dtype", + "dtypes", "einops", "elhage", + "endoftext", "eqnarray", "esben", "evals", + "explictly", "fazl", "firstpage", "fspath", @@ -50,21 +60,26 @@ "howpublished", "huggingface", "icml", + "idxs", "imshow", "interp", "interpretability", "ioannis", "ipynb", + "isin", "isort", "janiak", + "Janky", "jaxtyping", "jett", + "kaiming", "keepdim", "kissane", "konstas", "kran", "lastpage", "layernorm", + "ldim", "lieberum", "logits", "logsumexp", @@ -72,31 +87,50 @@ "maxdepth", "mingpt", "nanda", + "ndarray", + "ndim", "neel", "neox", "nitpicky", + "occurences", "olah", + "openwebtext", + "overcomplete", + "Overriden", "pagename", "pauly", "pretrained", "probs", + "producting", "pycln", "pypi", "pytest", + "randn", + "rdim", + "relu", "resid", + "rprint", "rtml", + "rtol", + "shortformer", "softmax", + "softmaxing", "solu", + "stas", "templatedir", "templatename", "toctree", "topk", + "tqdm", "transformerlens", + "tril", + 
"triu", "troitskii", "unembed", "unembedded", "unembedding", "unigram", + "unsqueeze", "virtualenvs", "visualisation", "xaxis", diff --git a/.vscode/settings.json b/.vscode/settings.json index 1c479871d..2fa400667 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -6,7 +6,7 @@ "editor.defaultFormatter": "tamasfe.even-better-toml" }, "editor.codeActionsOnSave": { - "source.organizeImports": true + "source.organizeImports": "explicit" }, "editor.formatOnSave": true, "evenBetterToml.formatter.allowedBlankLines": 1, @@ -38,4 +38,6 @@ "rewrap.autoWrap.enabled": true, "rewrap.reformat": true, "rewrap.wrappingColumn": 100, + "mypy.runUsingActiveInterpreter": true, + "editor.defaultFormatter": "ms-python.black-formatter", } \ No newline at end of file diff --git a/tests/unit/test_tokenization_methods.py b/tests/unit/test_tokenization_methods.py index 83e189c03..d795d453a 100644 --- a/tests/unit/test_tokenization_methods.py +++ b/tests/unit/test_tokenization_methods.py @@ -126,7 +126,7 @@ def test_get_token_position_not_found(): with pytest.raises(AssertionError) as exc_info: model.get_token_position(single, input) assert ( - str(exc_info.value) == "The token does not occur in the prompt" + str(exc_info.value) == f"The token does not occur in the prompt" ), "assertion error" diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 6995a7f00..f08098c26 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -3,6 +3,7 @@ import numpy as np import pytest import torch +from torch import nn import transformer_lens.utils as utils from transformer_lens import HookedTransformer @@ -377,3 +378,182 @@ def test_get_attention_mask( else: # otherwise, there should be no attended but non-pad token assert attended_but_non_pad_mask.sum() == 0 + + +def test_calc_fan_in_fan_out(): + """ + Test for the calc_fan_in_and_fan_out function in the utils module. 
+ """ + # Test for the case when the tensor is 1D + tensor_1d = torch.tensor([1, 2, 3, 4, 5]) + fan_in, fan_out = utils.calc_fan_in_and_fan_out(tensor_1d) + assert fan_in == 1 + assert fan_out == 5 + + # Test for the case when the tensor is 2D + tensor_2d = torch.tensor([[1, 2, 3], [4, 5, 6]]) + fan_in, fan_out = utils.calc_fan_in_and_fan_out(tensor_2d) + assert fan_in == 2 + assert fan_out == 3 + + # Test for the case when the tensor is 3D + tensor_3d = nn.Parameter( + torch.rand(2, 25, 5) + ) # 2 x 25 x 5, I'm not writing this out + fan_in, fan_out = utils.calc_fan_in_and_fan_out(tensor_3d) + assert fan_in == 25 + assert fan_out == 10 + + # Test for the case when the tensor is 4D (should raise a ValueError) + tensor_4d = torch.tensor([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]) + with pytest.raises(ValueError): + fan_in, fan_out = utils.calc_fan_in_and_fan_out(tensor_4d) + + # Test for the case when the tensor is 0D (also should raise a ValueError) + tensor_0d = torch.tensor(1) + with pytest.raises(ValueError): + fan_in, fan_out = utils.calc_fan_in_and_fan_out(tensor_0d) + + +class TestInitKaiming: + """Test cases for kaiming init.""" + + @pytest.mark.parametrize( + "d_model", [4096, 10_000] + ) # this needs to be large so std and min/max estimates are accurate + @pytest.mark.parametrize("d_mlp", [256, 512]) + @pytest.mark.parametrize("nonlinearity", ["linear", "relu"]) + def test_init_kaiming_uniform(self, d_model, d_mlp, nonlinearity): + """ + Test init_kaiming_uniform function in the utils module on 3/2/1D tensors. 
+ """ + torch.manual_seed(1234) + + gain = np.sqrt(2.0) if nonlinearity == "relu" else 1.0 + + x = nn.Parameter(torch.empty(2, d_model, 137)) # n_head and d_head don't matter + utils.init_kaiming_uniform_(x, nonlinearity=nonlinearity) + std = gain / np.sqrt(d_model) + assert np.isclose(x.std().detach().numpy(), std, rtol=1e-2) + # for uniform distributions, min/max is sqrt(3) times the std + assert np.isclose(x.max().detach().numpy(), np.sqrt(3) * std, rtol=1e-2) + assert np.isclose(x.min().detach().numpy(), -np.sqrt(3) * std, rtol=1e-2) + + y = nn.Parameter(torch.empty(d_mlp, d_model)) + utils.init_kaiming_uniform_(y, nonlinearity=nonlinearity) + std = gain / np.sqrt(d_mlp) + assert np.isclose(y.std().detach().numpy(), std, rtol=1e-2) + # for uniform distributions, min/max is sqrt(3) times the std + assert np.isclose(y.max().detach().numpy(), np.sqrt(3) * std, rtol=1e-2) + assert np.isclose(y.min().detach().numpy(), -np.sqrt(3) * std, rtol=1e-2) + + z = nn.Parameter(torch.empty(d_model * 123)) + utils.init_kaiming_uniform_(z, nonlinearity=nonlinearity) + std = gain # bias has fan_in 1 + assert np.isclose(z.std().detach().numpy(), std, rtol=1e-2) + # for uniform distributions, min/max is sqrt(3) times the std + assert np.isclose(z.max().detach().numpy(), np.sqrt(3) * std, rtol=1e-2) + assert np.isclose(z.min().detach().numpy(), -np.sqrt(3) * std, rtol=1e-2) + + torch.manual_seed(1234) + x_new = nn.Parameter(torch.empty(2, d_model, 137)) + utils.init_kaiming_uniform_(x_new, nonlinearity=nonlinearity) + assert torch.allclose(x_new, x, rtol=1e-2) + + @pytest.mark.parametrize("d_model", [4096, 10_000]) + @pytest.mark.parametrize("d_mlp", [256, 512]) + @pytest.mark.parametrize("nonlinearity", ["linear", "relu"]) + def test_init_kaiming_normal(self, d_model, d_mlp, nonlinearity): + """ + Test init_kaiming_normal function in the utils module on 3/2/1D tensors. 
+ """ + torch.manual_seed(1234) + + gain = np.sqrt(2.0) if nonlinearity == "relu" else 1.0 + + x = nn.Parameter(torch.empty(2, d_model, 137)) + utils.init_kaiming_normal_(x, nonlinearity=nonlinearity) + std = gain / np.sqrt(d_model) + assert np.isclose(x.std().detach().numpy(), std, rtol=1e-2) + + y = nn.Parameter(torch.empty(d_mlp, d_model)) + utils.init_kaiming_normal_(y, nonlinearity=nonlinearity) + std = gain / np.sqrt(d_mlp) + assert np.isclose(y.std().detach().numpy(), std, rtol=1e-2) + + z = nn.Parameter(torch.empty(d_model * 123)) + utils.init_kaiming_normal_(z, nonlinearity=nonlinearity) + std = gain # bias has fan_in 1 + assert np.isclose(z.std().detach().numpy(), std, rtol=1e-2) + + torch.manual_seed(1234) + x_new = nn.Parameter(torch.empty(2, d_model, 137)) + utils.init_kaiming_normal_(x_new, nonlinearity=nonlinearity) + assert torch.allclose(x_new, x, rtol=1e-2) + + +class TestInitXavier: + """Test cases for Xavier init. Std of distribution should be scaled to sqrt(2/(fan_in + fan_out)).""" + + @pytest.mark.parametrize("d_model", [4096, 10_000]) + @pytest.mark.parametrize("d_mlp", [256, 512]) + def test_init_xavier_uniform(self, d_model, d_mlp): + """Test init_xavier_uniform function in the utils module on 3/2/1D tensors.""" + torch.manual_seed(1234) + + x = nn.Parameter(torch.empty(2, d_model, 137)) + utils.init_xavier_uniform_(x) + std = np.sqrt(2 / (d_model + 137 * 2)) + assert np.isclose(x.std().detach().numpy(), std, rtol=1e-2) + # for uniform distributions, min/max is sqrt(3) times the std + assert np.isclose(x.max().detach().numpy(), np.sqrt(3) * std, rtol=1e-2) + assert np.isclose(x.min().detach().numpy(), -np.sqrt(3) * std, rtol=1e-2) + + y = nn.Parameter(torch.empty(d_mlp, d_model)) + utils.init_xavier_uniform_(y) + std = np.sqrt(2 / (d_mlp + d_model)) + assert np.isclose(y.std().detach().numpy(), std, rtol=1e-2) + # for uniform distributions, min/max is sqrt(3) times the std + assert np.isclose(y.max().detach().numpy(), np.sqrt(3) * std, 
rtol=1e-2) + assert np.isclose(y.min().detach().numpy(), -np.sqrt(3) * std, rtol=1e-2) + + z = nn.Parameter(torch.empty(d_model * 123)) + utils.init_xavier_uniform_(z) + std = np.sqrt(2 / (1 + d_model * 123)) + assert np.isclose(z.std().detach().numpy(), std, rtol=1e-2) + # for uniform distributions, min/max is sqrt(3) times the std + assert np.isclose(z.max().detach().numpy(), np.sqrt(3) * std, rtol=1e-2) + assert np.isclose(z.min().detach().numpy(), -np.sqrt(3) * std, rtol=1e-2) + + torch.manual_seed(1234) + x_new = nn.Parameter(torch.empty(2, d_model, 137)) + utils.init_xavier_uniform_(x_new) + assert torch.allclose(x_new, x, rtol=1e-2) + + @pytest.mark.parametrize("d_model", [4096, 10_000]) + @pytest.mark.parametrize("d_mlp", [256, 512]) + def test_init_xavier_normal(self, d_model, d_mlp): + """Test init_xavier_normal function in the utils module on 3/2/1D tensors.""" + torch.manual_seed(1234) + + x = nn.Parameter(torch.empty(2, d_model, 137)) + utils.init_xavier_normal_(x) + std = np.sqrt(2 / (d_model + 137 * 2)) + assert np.isclose(x.std().detach().numpy(), std, rtol=1e-2) + + y = nn.Parameter(torch.empty(d_mlp, d_model)) + utils.init_xavier_normal_(y) + std = np.sqrt(2 / (d_mlp + d_model)) + assert np.isclose(y.std().detach().numpy(), std, rtol=1e-2) + + z = nn.Parameter( + torch.empty(d_model * 123) + ) # need to make this larger so std is accurate + utils.init_xavier_normal_(z) + std = np.sqrt(2 / (1 + d_model * 123)) + assert np.isclose(z.std().detach().numpy(), std, rtol=1e-2) + + torch.manual_seed(1234) + x_new = nn.Parameter(torch.empty(2, d_model, 137)) + utils.init_xavier_normal_(x_new) + assert torch.allclose(x_new, x, rtol=1e-2) diff --git a/transformer_lens/HookedEncoder.py b/transformer_lens/HookedEncoder.py index 99cdc1cdd..03469ee6e 100644 --- a/transformer_lens/HookedEncoder.py +++ b/transformer_lens/HookedEncoder.py @@ -270,6 +270,9 @@ def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: @property def b_U(self) -> Float[torch.Tensor, 
"d_vocab"]: + """ + Convenience to get the unembedding bias + """ return self.unembed.b_U @property @@ -379,13 +382,17 @@ def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: @property def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] + """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head. + Useful for visualizing attention patterns.""" return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) @property def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] + """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head.""" return FactoredMatrix(self.W_V, self.W_O) def all_head_labels(self) -> List[str]: + """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index.""" return [ f"L{l}H{h}" for l in range(self.cfg.n_layers) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index eb6fd6ba9..1d717a970 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -43,7 +43,13 @@ # generation. from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache from transformer_lens.utilities import devices -from transformer_lens.utils import USE_DEFAULT_VALUE +from transformer_lens.utils import ( + USE_DEFAULT_VALUE, + init_kaiming_normal_, + init_kaiming_uniform_, + init_xavier_normal_, + init_xavier_uniform_, +) SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor LossPerToken = Float[torch.Tensor, "batch pos-1"] @@ -113,7 +119,7 @@ def __init__( "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a " "pretrained model, use HookedTransformer.from_pretrained() instead." 
) - self.cfg = cfg + self.cfg: HookedTransformerConfig = cfg if tokenizer is not None: self.set_tokenizer(tokenizer, default_padding_side=default_padding_side) @@ -121,7 +127,8 @@ def __init__( # If we have a tokenizer name, we can load it from HuggingFace if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES: logging.warning( - f"{self.cfg.tokenizer_name} tokenizer not loaded. Please load manually." + "%s tokenizer not loaded. Please load manually.", + self.cfg.tokenizer_name, ) else: # Hugging Face defaults to use_fast to True @@ -188,7 +195,7 @@ def __init__( pass else: logging.warning( - f"Invalid normalization_type passed in {self.cfg.normalization_type}" + "Invalid normalization_type passed in %s", self.cfg.normalization_type ) self.unembed = Unembed(self.cfg) @@ -525,7 +532,7 @@ def forward( tokens that have already been through the model. Also caches attention_mask so previous tokens are masked correctly (unless frozen). Padding should be ignored in all cases, so it's okay to eg. pass in left padded tokens twice in a row. - Warning: Don't accidently prepend_bos to the second half of a prompt. + Warning: Don't accidentally prepend_bos to the second half of a prompt. Defaults to None (don't use caching). """ @@ -982,7 +989,7 @@ def get_token_position( indices = torch.arange(len(tokens), device=tokens.device)[ tokens == single_token ] - assert len(indices) > 0, f"The token does not occur in the prompt" + assert len(indices) > 0, "The token does not occur in the prompt" if mode == "first": return indices[0].item() elif mode == "last": @@ -1360,10 +1367,6 @@ def from_pretrained_no_processing( def init_weights(self): """Initialize weights. - Initialize weights matrices with a normal of std=initializer_range (default=0.02). This - roughly follows the GPT-2 paper's scheme (but with truncation, and not halving the std for - W_pos). 
- LayerNorm weights are already initialized to 1.0, and all biases are initialized to 0.0 (including LayerNorm), so this just initializes weight matrices. @@ -1376,20 +1379,120 @@ def init_weights(self): This does NOT follow the PyTorch scheme, which as far as I can tell is super out of date but no one has gotten round to updating it? https://github.com/pytorch/pytorch/issues/18182 + The default PyTorch scheme is the following: all linear layers use uniform(-1/sqrt(fan_in), + 1/sqrt(fan_in)) for weights, and uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for biases. For + biases, fan_in is computed using the fan_in for the weight matrix of the linear layer. Note + that it *does not actually* use Kaiming initialization, despite the fact that it calls the + function. + + However, for Transformer blocks, it instead initializes biases to zero and weights using Xavier uniform, that + is: uniform(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))) for weights. + PyTorch Transformers are especially bad - TransformerEncoder initializes all layers to the - exact same weights?! https://github.com/pytorch/pytorch/issues/72253 + exact same weights?! https://github.com/pytorch/pytorch/issues/72253. The best paper I've found on transformer initialization is the muP paper, but haven't integrated those ideas yet: https://arxiv.org/abs/2203.03466 + + We split off the initialization into separate functions because muP initialization handles + different parts of the model differently. 
""" if self.cfg.seed is not None: torch.manual_seed(self.cfg.seed) + if self.cfg.init_mode == "gpt2": + self._init_weights_gpt2() + elif self.cfg.init_mode == "xavier_uniform": + self._init_weights_xavier(dist_type="uniform") + elif self.cfg.init_mode == "xavier_normal": + self._init_weights_xavier(dist_type="normal") + elif self.cfg.init_mode == "kaiming_uniform": + self._init_weights_kaiming(dist_type="uniform") + elif self.cfg.init_mode == "kaiming_normal": + self._init_weights_kaiming(dist_type="normal") + elif self.cfg.init_mode == "muP": + self._init_weights_muP(dist_type="normal") # muP uses normal initialization + + def _init_weights_gpt2(self): + """Initialize weights with GPT-2 initialization. Biases are initialized to 0.0 and weights + are initialized to N(0, 0.64/d_model) if initializer_range is not set, otherwise std is initializer_range. + """ for name, param in self.named_parameters(): if "W_" in name: nn.init.normal_(param, std=self.cfg.initializer_range) + def _init_weights_xavier(self, dist_type="normal"): + """ + Initialize weights with Xavier initialization -- that is, scale the weights by sqrt(6 / + (fan_in + fan_out)) for a [-1, 1] uniform distribution, or sqrt(2 / (fan_in + fan_out)) for a + standard normal. + + Note that since TransformerLens implements the matrices in the opposite orientation to what + torch does (e.g. it's d_in x d_out, not d_out x d_in as in torch), we need to calculate it + ourselves. + """ + gain = self.cfg.initializer_range + for name, param in self.named_parameters(): + if "W_" in name: + if dist_type == "uniform": + init_xavier_uniform_(param, gain=gain) + elif dist_type == "normal": + init_xavier_normal_(param, gain=gain) + + def _init_weights_kaiming(self, dist_type="uniform"): + """ + Initialize weights with Kaiming initialization -- that is, scale the weights by + c / sqrt(fan_in), where c = sqrt(2) if the params were immediately preceded by a relu and 1 for + everything else. 
+ + Note that the numbers are actually incorrect here when you're using a nonlinearity other + than relu, e.g. the correct c for SiLu is ~1.74, for tanh it's 5/3 ~= 1.67, and for GeLU it's ~1.57. + But this is unlikely to matter in practice. + + I'm just using fan_mode = "fan_in" for now, but it should be trivial to add fan_out. + + Again, we have to implement it ourselves because of the orientation of the matrices. + """ + gain = self.cfg.initializer_range + for name, param in self.named_parameters(): + if "W_" in name: + if dist_type == "uniform": + init_kaiming_uniform_( + param, gain=gain, nonlinearity="relu", mode="fan_in" + ) + elif dist_type == "normal": + init_kaiming_normal_( + param, gain=gain, nonlinearity="relu", mode="fan_in" + ) + + def _init_weights_muP(self, dist_type="uniform"): + """ + Initialize weights with muParameterization. This involves scaling output weights by a factor + of 1/fan_in, input weights and biases by 1, everything else by a factor of 1/sqrt(fan_in). + + Also, you need to use muAdamW, which rescales the learning rate for output weights and + hidden weights by a factor of 1/fan_in. + + All biases are still assumed to be initialized to 0.0, so we only need to change the + weights. 
+ """ + for name, param in self.named_parameters(): + if "W_" in name: + fan_in, _ = utils.calc_fan_in_and_fan_out(param) + if "embed" in name: + scale = 1 + elif "unembed" in name: + scale = 1 / fan_in + else: + scale = 1 / fan_in**0.5 + + if dist_type == "uniform": + scale *= 3**0.5 + nn.init.uniform_(param, -scale, scale) + elif dist_type == "normal": + nn.init.normal_(param, std=scale) + def load_and_process_state_dict( self, state_dict: Dict[str, torch.Tensor], @@ -2284,6 +2387,7 @@ def all_composition_scores( return scores def all_head_labels(self): + """Returns a list of all head names in the model.""" return [ f"L{l}H{h}" for l in range(self.cfg.n_layers) diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py index 501f6e881..09ed1b166 100644 --- a/transformer_lens/HookedTransformerConfig.py +++ b/transformer_lens/HookedTransformerConfig.py @@ -78,9 +78,8 @@ class HookedTransformerConfig: local attention weight_init_mode (str): the initialization mode to use for the weights. Only relevant for custom models, ignored for pre-trained. - Currently the only supported mode is 'gpt2', where biases are - initialized to 0 and weights are standard normals of range - initializer_range. + We now support 'gpt2', 'xavier_uniform', 'xavier_normal', 'kaiming_uniform', + 'kaiming_normal'. MuP support to come. Defaults to 'gpt2'. normalization_type (str, *optional*): the type of normalization to use. Options are None (no normalization), 'LN' (use LayerNorm, including weights & biases) and 'LNPre' (use LayerNorm, but no weights & biases). @@ -98,7 +97,9 @@ class HookedTransformerConfig: Used to set sources of randomness (Python, PyTorch and NumPy) and to initialize weights. Defaults to None. We recommend setting a seed, so your experiments are reproducible. initializer_range (float): The standard deviation of the normal used to - initialise the weights, initialized to 0.8 / sqrt(d_model) . 
+ initialise the weights, initialized to 0.8 / sqrt(d_model). If weight_init_mode is + 'xavier_uniform' or 'xavier_normal', this value is instead treated as the `gain` parameter for the weight + initialisation (a constant factor to scale the weights by). Defaults to -1.0, which means not set. init_weights (bool): Whether to initialize the weights. Defaults to True. If False, does not initialize weights. scale_attn_by_inverse_layer_idx (bool): Whether to scale the attention @@ -209,8 +210,14 @@ def __post_init__(self): self.n_heads = self.d_model // self.d_head if not self.d_model % (self.d_head) == 0: + # logging.warning( + # f"d_model {self.d_model} is not divisible by d_head {self.d_head}. n_heads was inferred to be {self.n_heads}, rounding down the ratio." + # ) logging.warning( - f"d_model {self.d_model} is not divisible by d_head {self.d_head}. n_heads was inferred to be {self.n_heads}, rounding down the ratio." + "d_model %d is not divisible by d_head %d. n_heads was inferred to be %d, rounding down the ratio.", + self.d_model, + self.d_head, + self.n_heads, ) if self.seed is not None: @@ -225,16 +232,19 @@ def __post_init__(self): if not self.attn_only: if self.d_mlp is None: # For some reason everyone hard codes in this hyper-parameter! 
- self.d_mlp = self.d_model * 4 + self.d_mlp: int = self.d_model * 4 assert ( self.act_fn is not None ), "act_fn must be specified for non-attn-only models" assert ( self.act_fn in SUPPORTED_ACTIVATIONS ), f"act_fn={self.act_fn} must be one of {SUPPORTED_ACTIVATIONS}" - if self.initializer_range < 0: + if self.initializer_range < 0 and self.init_mode == "gpt2": # Roughly copy the GPT-2 value, but proportional to sqrt(1/d_model) self.initializer_range = 0.8 / np.sqrt(self.d_model) + if self.initializer_range < 0 and self.init_mode != "gpt2": + # This is the gain parameter for the weight initialisation + self.initializer_range = 1.0 if self.d_vocab_out == -1: # d_vocab_out defaults to d_vocab, unless there's an algorithmic task diff --git a/transformer_lens/components.py b/transformer_lens/components.py index cdda30ccd..d2fe5ed0a 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -399,7 +399,7 @@ def __init__( Args: cfg (Union[Dict, HookedTransformerConfig]): Config attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global". - layer_id (int, optional): The index of the current layer. Used by the Mistal models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. + layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. 
""" super().__init__() if isinstance(cfg, Dict): diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index c549d6066..5c74dac26 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -15,6 +15,7 @@ import einops import numpy as np import torch +import torch.nn as nn import torch.nn.functional as F import transformers from datasets.arrow_dataset import Dataset @@ -198,6 +199,83 @@ def solu( return input * F.softmax(input, dim=-1) +def calc_fan_in_and_fan_out(tensor): + """ + Calculate the fan in and fan out of a tensor. We define it ourselves because Torch uses a + different convention for weights (e.g. for an MLP they use d_out x d_in, and we use d_in x + d_out, for attention they do (n_head d_head) x d_model, we do n_head x d_model x d_head). + """ + shape = tensor.shape + + if len(shape) == 0: + raise ValueError("Fan in and fan out can not be computed for scalars.") + elif len(shape) == 1: + fan_in = 1 + fan_out = shape[0] + elif len(shape) == 2: # Linear transform + fan_in = shape[0] + fan_out = shape[1] + elif len(shape) == 3: # Attention head weight, has shape n_head x d_model x d_head + fan_in = shape[1] + fan_out = shape[0] * shape[2] + else: + raise ValueError( + f"Fan in and fan out can not be computed for shape {shape} tensors." + ) + + return fan_in, fan_out + + +def init_xavier_uniform_(param, gain=1.0): + """ + Initializes the input tensor using the Xavier initialization method. + """ + fan_in, fan_out = calc_fan_in_and_fan_out(param) + max = gain * np.sqrt(6.0 / (fan_in + fan_out)) + return nn.init.uniform_(param, -max, max) + + +def init_xavier_normal_(param, gain=1.0): + """ + Initializes the input tensor using the Xavier initialization method. 
+ """ + fan_in, fan_out = calc_fan_in_and_fan_out(param) + std = gain * np.sqrt(2.0 / (fan_in + fan_out)) + return nn.init.normal_(param, mean=0.0, std=std) + + +def init_kaiming_uniform_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"): + """ + Initializes the input tensor using the Kaiming initialization method. + + Starting from a std 1 uniform distribution, we scale the weights by c / sqrt(fan_in), where c = + sqrt(2) if the params were immediately preceded by a relu and 1 for everything else. + + As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one. + """ + fan_in, fan_out = calc_fan_in_and_fan_out(param) + fan = fan_in if mode == "fan_in" else fan_out + gain *= nn.init.calculate_gain(nonlinearity, a) + max = gain * np.sqrt(3.0 / fan) + return nn.init.uniform_(param, -max, max) + + +def init_kaiming_normal_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"): + """ + Initializes the input tensor using the Kaiming initialization method. + + Starting from a std 1 normal distribution, we scale the weights by c / sqrt(fan_in), where c = + sqrt(2) if the params were immediately preceded by a relu and 1 for everything else. + + As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one. 
+ """ + fan_in, fan_out = calc_fan_in_and_fan_out(param) + fan = fan_in if mode == "fan_in" else fan_out + gain *= nn.init.calculate_gain(nonlinearity, a) + std = gain * np.sqrt(1.0 / fan) + return nn.init.normal_(param, mean=0.0, std=std) + + def keep_single_column(dataset: Dataset, col_name: str): """ Acts on a HuggingFace dataset to delete all columns apart from a single column name - useful when we want to tokenize and mix together different strings @@ -283,16 +361,6 @@ def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]: return tokenized_dataset -""" -Test ^ - -data = Dataset.from_dict({"text":[str(i) for i in range(1000)]}) -tokenizer = AutoTokenizer.from_pretrained("NeelNanda/gpt-neox-tokenizer-digits") -print(data) -tokenize_and_concatenate(data, tokenizer, streaming=False, column_name="text") -""" - - def sample_logits( final_logits: Float[torch.Tensor, "batch d_vocab"], top_k: Optional[int] = None, From afe79001d17a798c07a587b83fb5eefdcefcd160 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Mon, 8 Apr 2024 21:51:24 +0100 Subject: [PATCH 51/73] chore: fixing type errors and enabling mypy (#516) * chore: fixing type errors and enabling mypy * updated pyproject * fixing typing after merging updates * fixed correct typing for float --------- Co-authored-by: Bryce Meyer --- .github/workflows/checks.yml | 34 ++++----- poetry.lock | 2 +- pyproject.toml | 5 +- tests/unit/test_make_docs.py | 3 +- transformer_lens/ActivationCache.py | 62 ++++++++------- transformer_lens/FactoredMatrix.py | 10 +-- transformer_lens/HookedEncoder.py | 11 ++- transformer_lens/HookedTransformer.py | 85 +++++++++++---------- transformer_lens/HookedTransformerConfig.py | 2 + transformer_lens/SVDInterpreter.py | 20 ++--- transformer_lens/components.py | 33 ++++++-- transformer_lens/head_detector.py | 4 +- transformer_lens/hook_points.py | 43 +++++++---- transformer_lens/loading_from_pretrained.py | 39 ++++++---- transformer_lens/patching.py | 45 ++++++++++- 
transformer_lens/train.py | 12 ++- transformer_lens/utils.py | 39 +++++----- 17 files changed, 282 insertions(+), 167 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 1a4e35ccd..09b11e966 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -5,24 +5,24 @@ on: branches: - main paths: - - '**' # Include all files by default - - '!.devcontainer/**' - - '!.vscode/**' - - '!.git*' - - '!*.md' - - '!.github/**' - - '.github/workflows/checks.yml' # Still include current workflow + - "**" # Include all files by default + - "!.devcontainer/**" + - "!.vscode/**" + - "!.git*" + - "!*.md" + - "!.github/**" + - ".github/workflows/checks.yml" # Still include current workflow pull_request: branches: - main paths: - - '**' - - '!.devcontainer/**' - - '!.vscode/**' - - '!.git*' - - '!*.md' - - '!.github/**' - - '.github/workflows/checks.yml' + - "**" + - "!.devcontainer/**" + - "!.vscode/**" + - "!.git*" + - "!*.md" + - "!.github/**" + - ".github/workflows/checks.yml" # Allow this workflow to be called from other workflows workflow_call: inputs: @@ -73,11 +73,11 @@ jobs: run: make unit-test - name: Docstring test run: make docstring-test - # - name: Type check - # run: poetry run mypy transformer_lens + - name: Type check + run: poetry run mypy . - name: Build check run: poetry build - + # Acceptance tests are run in parallel with unit checks. 
acceptance-tests: name: Acceptance Tests diff --git a/poetry.lock b/poetry.lock index 0d13a6dd7..6aba9fa7e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5325,4 +5325,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "8128709b406f2fd78fa6fc77f68878861e7d95fc62acf1e0fc80f46b1026ba56" +content-hash = "0bc401f271115fc5955dccb3cca0d29c981bc204a8894378723ca738b4d0287e" diff --git a/pyproject.toml b/pyproject.toml index 0efa3822f..815641f29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ [tool.poetry.dependencies] accelerate=">=0.23.0" # Needed for Llama Models beartype="^0.14.1" + better-abc="^0.0.3" datasets=">=2.7.1" einops=">=0.6.0" fancy-einsum=">=0.0.3" @@ -35,7 +36,6 @@ transformers=">=4.37.2" typing-extensions="*" wandb=">=0.13.5" - better-abc = "^0.0.3" sentencepiece = "*" [tool.poetry.group] @@ -44,7 +44,7 @@ circuitsvis=">=1.38.1" isort="5.8.0" jupyter=">=1.0.0" - mypy=">=0.991" + mypy=">=1.8.0" nbval="^0.10.0" plotly=">=5.12.0" pycln="^2.1.3" @@ -91,6 +91,7 @@ [tool.mypy] check_untyped_defs=true + exclude=[".venv/", "assets", "demos", "docs", "easy_transformer", "tests"] ignore_missing_imports=true [tool.black] diff --git a/tests/unit/test_make_docs.py b/tests/unit/test_make_docs.py index fdb421962..47f5afe96 100644 --- a/tests/unit/test_make_docs.py +++ b/tests/unit/test_make_docs.py @@ -1,8 +1,9 @@ """Make Docs Tests.""" + import pytest from docs.make_docs import get_config, get_property -from transformer_lens import HookedTransformerConfig +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig def test_get_config(): diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index 83b5b4110..31906fade 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -15,7 +15,7 @@ class first, including the examples, and then skimming the available 
methods. Yo import logging import warnings -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast import einops import numpy as np @@ -277,7 +277,7 @@ def items(self): """ return self.cache_dict.items() - def __iter__(self) -> Iterator[Tuple[str, torch.Tensor]]: + def __iter__(self) -> Iterator[str]: """ActivationCache Iterator. Special method that returns an iterator over the ActivationCache. Allows looping over the @@ -315,6 +315,7 @@ def apply_slice_to_batch_dim( """ if not isinstance(batch_slice, Slice): batch_slice = Slice(batch_slice) + batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this assert ( self.has_batch_dim or batch_slice.mode == "empty" ), "Cannot index into a cache without a batch dim" @@ -330,11 +331,11 @@ def apply_slice_to_batch_dim( def accumulated_resid( self, layer: Optional[int] = None, - incl_mid: Optional[bool] = False, - apply_ln: Optional[bool] = False, + incl_mid: bool = False, + apply_ln: bool = False, pos_slice: Optional[Union[Slice, SliceInput]] = None, - mlp_input: Optional[bool] = False, - return_labels: Optional[bool] = False, + mlp_input: bool = False, + return_labels: bool = False, ) -> Union[ Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], Tuple[ @@ -439,19 +440,21 @@ def accumulated_resid( layer = self.model.cfg.n_layers assert isinstance(layer, int) labels = [] - components = [] + components_list = [] for l in range(layer + 1): if l == self.model.cfg.n_layers: - components.append(self[("resid_post", self.model.cfg.n_layers - 1)]) + components_list.append( + self[("resid_post", self.model.cfg.n_layers - 1)] + ) labels.append("final_post") continue - components.append(self[("resid_pre", l)]) + components_list.append(self[("resid_pre", l)]) labels.append(f"{l}_pre") if (incl_mid and l < layer) or (mlp_input and l == layer): - components.append(self[("resid_mid", l)]) + 
components_list.append(self[("resid_mid", l)]) labels.append(f"{l}_mid") - components = [pos_slice.apply(c, dim=-2) for c in components] - components = torch.stack(components, dim=0) + components_list = [pos_slice.apply(c, dim=-2) for c in components_list] + components = torch.stack(components_list, dim=0) if apply_ln: components = self.apply_ln_to_stack( components, layer, pos_slice=pos_slice, mlp_input=mlp_input @@ -633,6 +636,7 @@ def decompose_resid( """ if not isinstance(pos_slice, Slice): pos_slice = Slice(pos_slice) + pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this if layer is None or layer == -1: # Default to the residual stream immediately pre unembed layer = self.model.cfg.n_layers @@ -640,28 +644,28 @@ def decompose_resid( incl_attn = mode != "mlp" incl_mlp = mode != "attn" and not self.model.cfg.attn_only - components = [] + components_list = [] labels = [] if incl_embeds: if self.has_embed: - components = [self["hook_embed"]] + components_list = [self["hook_embed"]] labels.append("embed") if self.has_pos_embed: - components.append(self["hook_pos_embed"]) + components_list.append(self["hook_pos_embed"]) labels.append("pos_embed") for l in range(layer): if incl_attn: - components.append(self[("attn_out", l)]) + components_list.append(self[("attn_out", l)]) labels.append(f"{l}_attn_out") if incl_mlp: - components.append(self[("mlp_out", l)]) + components_list.append(self[("mlp_out", l)]) labels.append(f"{l}_mlp_out") if mlp_input and incl_attn: - components.append(self[("attn_out", layer)]) + components_list.append(self[("attn_out", layer)]) labels.append(f"{layer}_attn_out") - components = [pos_slice.apply(c, dim=-2) for c in components] - components = torch.stack(components, dim=0) + components_list = [pos_slice.apply(c, dim=-2) for c in components_list] + components = torch.stack(components_list, dim=0) if apply_ln: components = self.apply_ln_to_stack( components, layer, pos_slice=pos_slice, mlp_input=mlp_input @@ -725,6 +729,7 @@ 
def stack_head_results( """ if not isinstance(pos_slice, Slice): pos_slice = Slice(pos_slice) + pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this if layer is None or layer == -1: # Default to the residual stream immediately pre unembed layer = self.model.cfg.n_layers @@ -735,7 +740,7 @@ def stack_head_results( ) self.compute_head_results() - components = [] + components: Any = [] labels = [] for l in range(layer): # Note that this has shape batch x pos x head_index x d_model @@ -771,7 +776,7 @@ def stack_head_results( components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice) if return_labels: - return components, labels + return components, labels # type: ignore # TODO: fix this properly else: return components @@ -830,9 +835,9 @@ def get_neuron_results( Returns: Tensor of the results. """ - if type(neuron_slice) is not Slice: + if not isinstance(neuron_slice, Slice): neuron_slice = Slice(neuron_slice) - if type(pos_slice) is not Slice: + if not isinstance(pos_slice, Slice): pos_slice = Slice(pos_slice) neuron_acts = self[("post", layer, "mlp")] @@ -890,7 +895,7 @@ def stack_neuron_results( # Default to the residual stream immediately pre unembed layer = self.model.cfg.n_layers - components = [] + components: Any = [] # TODO: fix typing properly labels = [] if not isinstance(neuron_slice, Slice): @@ -898,7 +903,9 @@ def stack_neuron_results( if not isinstance(pos_slice, Slice): pos_slice = Slice(pos_slice) - neuron_labels = neuron_slice.apply(torch.arange(self.model.cfg.d_mlp), dim=0) + neuron_labels: torch.Tensor | np.ndarray = neuron_slice.apply( + torch.arange(self.model.cfg.d_mlp), dim=0 + ) if type(neuron_labels) == int: neuron_labels = np.array([neuron_labels]) for l in range(layer): @@ -1055,6 +1062,7 @@ def get_full_resid_decomposition( if layer is None or layer == -1: # Default to the residual stream immediately pre unembed layer = self.model.cfg.n_layers + assert layer is not None # keep mypy happy if not 
isinstance(pos_slice, Slice): pos_slice = Slice(pos_slice) @@ -1105,6 +1113,6 @@ def get_full_resid_decomposition( ) if return_labels: - return residual_stack, labels + return residual_stack, labels # type: ignore # TODO: fix this properly else: return residual_stack diff --git a/transformer_lens/FactoredMatrix.py b/transformer_lens/FactoredMatrix.py index 7f72a5df6..e037d45dc 100644 --- a/transformer_lens/FactoredMatrix.py +++ b/transformer_lens/FactoredMatrix.py @@ -52,7 +52,7 @@ def __matmul__( ... @overload - def __matmul__( + def __matmul__( # type: ignore self, other: Float[torch.Tensor, "rdim"], ) -> Float[torch.Tensor, "... ldim"]: @@ -83,7 +83,7 @@ def __matmul__( return (self @ other.A) @ other.B @overload - def __rmatmul__( + def __rmatmul__( # type: ignore self, other: Union[ Float[torch.Tensor, "... new_rdim ldim"], @@ -93,13 +93,13 @@ def __rmatmul__( ... @overload - def __rmatmul__( + def __rmatmul__( # type: ignore self, other: Float[torch.Tensor, "ldim"], ) -> Float[torch.Tensor, "... rdim"]: ... - def __rmatmul__( + def __rmatmul__( # type: ignore self, other: Union[ Float[torch.Tensor, "... new_rdim ldim"], @@ -131,7 +131,7 @@ def __mul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix: ), f"Tensor must be a scalar for use with * but was of shape {scalar.shape}. For matrix multiplication, use @ instead." return FactoredMatrix(self.A * scalar, self.B) - def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix: + def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix: # type: ignore """ Right scalar multiplication. For scalar multiplication from the right, we can reuse the __mul__ method. """ diff --git a/transformer_lens/HookedEncoder.py b/transformer_lens/HookedEncoder.py index 03469ee6e..51e357788 100644 --- a/transformer_lens/HookedEncoder.py +++ b/transformer_lens/HookedEncoder.py @@ -3,6 +3,7 @@ Contains a BERT style model. 
This is separate from :class:`transformer_lens.HookedTransformer` because it has a significantly different architecture to e.g. GPT style transformers. """ + from __future__ import annotations import logging @@ -16,9 +17,11 @@ from typing_extensions import Literal import transformer_lens.loading_from_pretrained as loading -from transformer_lens import ActivationCache, FactoredMatrix, HookedTransformerConfig +from transformer_lens.ActivationCache import ActivationCache from transformer_lens.components import BertBlock, BertEmbed, BertMLMHead, Unembed +from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookedRootModule, HookPoint +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.utilities import devices @@ -140,7 +143,7 @@ def forward( resid = self.mlm_head(resid) if return_type is None: - return + return None logits = self.unembed(resid) return logits @@ -153,7 +156,7 @@ def run_with_cache( @overload def run_with_cache( - self, *model_args, return_cache_object: Literal[False] = False, **kwargs + self, *model_args, return_cache_object: Literal[False], **kwargs ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: ... @@ -181,7 +184,7 @@ def run_with_cache( else: return out, cache_dict - def to( + def to( # type: ignore self, device_or_dtype: Union[torch.device, str, torch.dtype], print_details: bool = True, diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 1d717a970..a06e6178b 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -8,8 +8,9 @@ alteration of activations in individual components like attention heads and MLP layers, facilitating a deeper understanding of the internal workings of transformers like GPT-2. 
""" + import logging -from typing import Dict, List, NamedTuple, Optional, Tuple, Union, overload +from typing import Dict, List, NamedTuple, Optional, Tuple, Union, cast, overload import einops import numpy as np @@ -23,7 +24,6 @@ import transformer_lens.loading_from_pretrained as loading import transformer_lens.utils as utils -from transformer_lens import HookedTransformerConfig from transformer_lens.ActivationCache import ActivationCache from transformer_lens.components import ( Embed, @@ -37,6 +37,7 @@ ) from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookedRootModule, HookPoint +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES # Note - activation cache is used with run_with_cache, past_key_value_caching is used for @@ -89,6 +90,8 @@ class HookedTransformer(HookedRootModule): investigating. This can be done with :func:`transformer_lens.utils.test_prompt`. """ + ln_final: nn.Module + def __init__( self, cfg: Union[HookedTransformerConfig, Dict], @@ -271,7 +274,7 @@ def input_to_embed( past_kv_cache (HookedTransformerKeyValueCache, optional): If passed, we're doing caching and attention_mask will be stored in the cache. 
""" - if type(input) == str or type(input) == list: + if isinstance(input, str) or isinstance(input, list): # If text, convert to tokens (batch_size=1) assert ( self.tokenizer is not None @@ -370,7 +373,7 @@ def forward( self, input, return_type: Literal["logits"], - loss_per_token: Optional[bool] = False, + loss_per_token: bool = False, prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, padding_side: Optional[ Union[Literal["left", "right"], None] @@ -391,7 +394,7 @@ def forward( self, input, return_type: Literal["loss"], - loss_per_token: Optional[bool] = False, + loss_per_token: bool = False, prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, padding_side: Optional[ Union[Literal["left", "right"], None] @@ -412,7 +415,7 @@ def forward( self, input, return_type: Literal["both"], - loss_per_token: Optional[bool] = False, + loss_per_token: bool = False, prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, padding_side: Optional[ Union[Literal["left", "right"], None] @@ -433,7 +436,7 @@ def forward( self, input, return_type: Literal[None], - loss_per_token: Optional[bool] = False, + loss_per_token: bool = False, prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, padding_side: Optional[ Union[Literal["left", "right"], None] @@ -458,7 +461,7 @@ def forward( Float[torch.Tensor, "batch pos d_model"], ], return_type: Optional[str] = "logits", - loss_per_token: Optional[bool] = False, + loss_per_token: bool = False, prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, start_at_layer: Optional[int] = None, @@ -577,9 +580,9 @@ def forward( residual, # Cache contains a list of HookedTransformerKeyValueCache objects, one for each # block - past_kv_cache_entry=past_kv_cache[i] - if past_kv_cache is not None - else None, + past_kv_cache_entry=( + past_kv_cache[i] if past_kv_cache is not None else None + ), shortformer_pos_embed=shortformer_pos_embed, 
attention_mask=attention_mask, ) # [batch, pos, d_model] @@ -631,7 +634,7 @@ def run_with_cache( @overload def run_with_cache( - self, *model_args, return_cache_object: Literal[False] = False, **kwargs + self, *model_args, return_cache_object: Literal[False], **kwargs ) -> Tuple[Output, Dict[str, torch.Tensor]]: ... @@ -691,6 +694,7 @@ def set_tokenizer( # (https://github.com/huggingface/transformers/issues/25886). tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer) self.tokenizer = tokenizer_with_bos + assert self.tokenizer is not None # keep mypy happy self.tokenizer.padding_side = default_padding_side # Some tokenizers doesn't automatically prepend the BOS token even when they are initialized @@ -717,8 +721,8 @@ def to_tokens( padding_side: Optional[ Union[Literal["left", "right"], None] ] = USE_DEFAULT_VALUE, - move_to_device: Optional[bool] = True, - truncate: Optional[bool] = True, + move_to_device: bool = True, + truncate: bool = True, ) -> Int[torch.Tensor, "batch pos"]: """Converts a string to a tensor of tokens. 
@@ -865,6 +869,8 @@ def to_str_tokens( with utils.LocallyOverridenDefaults( self, prepend_bos=prepend_bos, padding_side=padding_side ): + assert self.tokenizer is not None # keep mypy happy + tokens: Union[np.ndarray, torch.Tensor] if isinstance(input, list): return list( map( @@ -901,7 +907,6 @@ def to_str_tokens( ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" else: raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}") - str_tokens = self.tokenizer.batch_decode( tokens, clean_up_tokenization_spaces=False ) @@ -924,7 +929,7 @@ def to_single_str_token(self, int_token: int) -> str: assert isinstance(int_token, int) token = self.to_str_tokens(torch.tensor([int_token])) assert len(token) == 1 - return token[0] + return cast(str, token[0]) def get_token_position( self, @@ -1054,10 +1059,10 @@ def tokens_to_residual_directions( residual_direction = self.W_U[:, token] return residual_direction - def to( + def to( # type: ignore self, device_or_dtype: Union[torch.device, str, torch.dtype], - print_details: Optional[bool] = True, + print_details: bool = True, ): return devices.move_to_and_update_config(self, device_or_dtype, print_details) @@ -1093,20 +1098,20 @@ def move_model_modules_to_device(self): def from_pretrained( cls, model_name: str, - fold_ln: Optional[bool] = True, - center_writing_weights: Optional[bool] = True, - center_unembed: Optional[bool] = True, - refactor_factored_attn_matrices: Optional[bool] = False, + fold_ln: bool = True, + center_writing_weights: bool = True, + center_unembed: bool = True, + refactor_factored_attn_matrices: bool = False, checkpoint_index: Optional[int] = None, checkpoint_value: Optional[int] = None, hf_model: Optional[AutoModelForCausalLM] = None, device: Optional[Union[str, torch.device]] = None, - n_devices: Optional[int] = 1, + n_devices: int = 1, tokenizer: Optional[PreTrainedTokenizerBase] = None, - move_to_device: Optional[bool] = True, - fold_value_biases: Optional[bool] = True, - 
default_prepend_bos: Optional[bool] = True, - default_padding_side: Optional[Literal["left", "right"]] = "right", + move_to_device: bool = True, + fold_value_biases: bool = True, + default_prepend_bos: bool = True, + default_padding_side: Literal["left", "right"] = "right", dtype="float32", **from_pretrained_kwargs, ) -> "HookedTransformer": @@ -1481,7 +1486,7 @@ def _init_weights_muP(self, dist_type="uniform"): if "W_" in name: fan_in, _ = utils.calc_fan_in_and_fan_out(param) if "embed" in name: - scale = 1 + scale = float(1) elif "unembed" in name: scale = 1 / fan_in else: @@ -1496,11 +1501,11 @@ def _init_weights_muP(self, dist_type="uniform"): def load_and_process_state_dict( self, state_dict: Dict[str, torch.Tensor], - fold_ln: Optional[bool] = True, - center_writing_weights: Optional[bool] = True, - center_unembed: Optional[bool] = True, - fold_value_biases: Optional[bool] = True, - refactor_factored_attn_matrices: Optional[bool] = False, + fold_ln: bool = True, + center_writing_weights: bool = True, + center_unembed: bool = True, + fold_value_biases: bool = True, + refactor_factored_attn_matrices: bool = False, ): """Load & Process State Dict. 
@@ -1695,7 +1700,7 @@ def fold_layer_norm( "mean", ) - if self.cfg.act_fn.startswith("solu"): + if self.cfg.act_fn is not None and self.cfg.act_fn.startswith("solu"): # Fold ln3 into activation if fold_biases: state_dict[f"blocks.{l}.mlp.b_out"] = state_dict[ @@ -1971,7 +1976,7 @@ def process_weights_( for layer in self.blocks: layer.ln1 = LayerNormPre(self.cfg) layer.ln2 = LayerNormPre(self.cfg) - if self.cfg.act_fn.endswith("_ln"): + if self.cfg.act_fn is not None and self.cfg.act_fn.endswith("_ln"): layer.mlp.ln = LayerNormPre(self.cfg) elif fold_ln and self.cfg.normalization_type == "RMS": # We do the same for RMSNorm if used @@ -1980,7 +1985,7 @@ def process_weights_( for layer in self.blocks: layer.ln1 = RMSNormPre(self.cfg) layer.ln2 = RMSNormPre(self.cfg) - if self.cfg.act_fn.endswith("_ln"): + if self.cfg.act_fn is not None and self.cfg.act_fn.endswith("_ln"): layer.mlp.ln = RMSNormPre(self.cfg) self.load_and_process_state_dict( @@ -2090,8 +2095,9 @@ def generate( else: past_kv_cache = None - stop_tokens = [] + stop_tokens: List[int] = [] eos_token_for_padding = 0 + assert self.tokenizer is not None if stop_at_eos: tokenizer_has_eos_token = ( self.tokenizer is not None @@ -2261,7 +2267,7 @@ def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: return torch.stack([block.mlp.W_in for block in self.blocks], dim=0) @property - def W_gate(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: + def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]: """Stack the MLP gate weights across all layers. Only works for models with gated MLPs. 
@@ -2437,7 +2443,7 @@ def load_sample_training_dataset(self, **kwargs): def sample_datapoint( self, - tokenize: Optional[bool] = False, + tokenize: bool = False, prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, ) -> Union[str, Float[torch.Tensor, "1 pos"]]: @@ -2464,6 +2470,7 @@ def sample_datapoint( """ if self.dataset is None: self.load_sample_training_dataset() + assert self.dataset is not None # keep mypy happy sample_dataset_size = len(self.dataset) index = np.random.randint(0, sample_dataset_size) if not tokenize: diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py index 09ed1b166..bb284a8b6 100644 --- a/transformer_lens/HookedTransformerConfig.py +++ b/transformer_lens/HookedTransformerConfig.py @@ -3,6 +3,7 @@ Module with a dataclass for storing the configuration of a :class:`transformer_lens.HookedTransformer` model. """ + from __future__ import annotations import logging @@ -259,6 +260,7 @@ def __post_init__(self): (self.d_model * self.d_head * self.n_heads * 4) ) if not self.attn_only: + assert self.d_mlp is not None # mypy # Number of parameters in MLP layers (ignoring biases and layer norm). 2 because W_in and W_out self.n_params += self.n_layers * self.d_model * self.d_mlp * 2 diff --git a/transformer_lens/SVDInterpreter.py b/transformer_lens/SVDInterpreter.py index caaecd448..c34c34e6f 100644 --- a/transformer_lens/SVDInterpreter.py +++ b/transformer_lens/SVDInterpreter.py @@ -3,6 +3,7 @@ Module for getting the singular vectors of the OV, w_in, and w_out matrices of a :class:`transformer_lens.HookedTransformer`. 
""" + from typing import Optional, Union import fancy_einsum as einsum @@ -10,7 +11,8 @@ from typeguard import typechecked from typing_extensions import Literal -from transformer_lens import FactoredMatrix, HookedTransformer +from transformer_lens.FactoredMatrix import FactoredMatrix +from transformer_lens.HookedTransformer import HookedTransformer OUTPUT_EMBEDDING = "unembed.W_U" VECTOR_TYPES = ["OV", "w_in", "w_out"] @@ -79,7 +81,9 @@ def plot_matrix(matrix, tokens, k=10, filter="topk"): "w_out", ], f"Head index optional only for w_in and w_out, got {vector_type}" + matrix: Union[FactoredMatrix, torch.Tensor] if vector_type == "OV": + assert head_index is not None # keep mypy happy matrix = self._get_OV_matrix(layer_index, head_index) V = matrix.Vh.T @@ -108,12 +112,12 @@ def _get_singular_vectors_from_matrix( ) -> torch.Tensor: """Returns the top num_vectors singular vectors from a matrix.""" - vectors = [] + vectors_list = [] for i in range(num_vectors): - activations = V[i, :].float() @ embedding - vectors.append(activations) + activations = V[i, :].float() @ embedding # type: ignore + vectors_list.append(activations) - vectors = torch.stack(vectors, dim=1).unsqueeze(1) + vectors = torch.stack(vectors_list, dim=1).unsqueeze(1) assert vectors.shape == ( self.cfg.d_vocab, 1, @@ -131,10 +135,8 @@ def _get_OV_matrix(self, layer_index: int, head_index: int) -> FactoredMatrix: 0 <= head_index < self.cfg.n_heads ), f"Head index must be between 0 and {self.cfg.n_heads-1} but got {head_index}" - W_V, W_O = ( - self.params[f"blocks.{layer_index}.attn.W_V"], - self.params[f"blocks.{layer_index}.attn.W_O"], - ) + W_V: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_V"] + W_O: torch.Tensor = self.params[f"blocks.{layer_index}.attn.W_O"] W_V, W_O = W_V[head_index, :, :], W_O[head_index, :, :] return FactoredMatrix(W_V, W_O) diff --git a/transformer_lens/components.py b/transformer_lens/components.py index d2fe5ed0a..452a31eea 100644 --- 
a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -4,9 +4,10 @@ needed to create many different types of generative language models. They are used by :class:`transformer_lens.HookedTransformer`. """ + import logging from abc import ABC -from typing import Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union import einops import numpy as np @@ -263,7 +264,7 @@ def forward( if self.cfg.dtype not in [torch.float32, torch.float64]: x = x.to(torch.float32) - x = x - x.mean(axis=-1, keepdim=True) # [batch, pos, length] + x = x - x.mean(-1, keepdim=True) # [batch, pos, length] scale: Union[ Float[torch.Tensor, "batch pos 1"], Float[torch.Tensor, "batch pos head_index 1"], @@ -311,7 +312,7 @@ def forward( if self.cfg.dtype not in [torch.float32, torch.float64]: x = x.to(torch.float32) - x = x - x.mean(axis=-1, keepdim=True) # [batch, pos, length] + x = x - x.mean(-1, keepdim=True) # [batch, pos, length] scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale( (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt() ) @@ -384,6 +385,8 @@ def forward( class AbstractAttention(ABC, nn.Module): + alibi: Union[torch.Tensor, None] + def __init__( self, cfg: Union[Dict, HookedTransformerConfig], @@ -450,6 +453,7 @@ def __init__( else: self.attn_scale = 1.0 if self.cfg.scale_attn_by_inverse_layer_idx: + assert self.layer_id is not None # keep mypy happy self.attn_scale *= self.layer_id + 1 self.hook_k = HookPoint() # [batch, pos, head_index, d_head] @@ -468,6 +472,7 @@ def __init__( # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. 
See HookedTransformerConfig for details self.hook_rot_k = HookPoint() self.hook_rot_q = HookPoint() + assert self.cfg.rotary_dim is not None # keep mypy happy sin, cos = self.calculate_sin_cos_rotary( self.cfg.rotary_dim, self.cfg.n_ctx, @@ -817,7 +822,7 @@ def apply_rotary( @staticmethod def create_alibi_slope( - n_ctx: int, device: torch.device = None + n_ctx: int, device: Optional[Union[str, torch.device]] = None ) -> Float[torch.Tensor, "query key"]: """Create an ALiBi Slope Matrix. @@ -859,7 +864,7 @@ def create_alibi_slope( @staticmethod def create_alibi_multipliers( - n_heads: int, device: torch.device = None + n_heads: int, device: Optional[Union[str, torch.device]] = None ) -> Float[torch.Tensor, "head_idx"]: """Create the ALiBi Scalar Multipliers for each Head. @@ -898,7 +903,7 @@ def create_alibi_multipliers( @staticmethod def create_alibi_bias( - n_heads: int, n_ctx: int, device: torch.device = None + n_heads: int, n_ctx: int, device: Optional[Union[torch.device, str]] = None ) -> Float[torch.Tensor, "head_idx query key"]: """Create the ALiBi Bias for all Heads. @@ -1167,11 +1172,15 @@ def calculate_z_scores( # MLP Layers class MLP(nn.Module): + act_fn: Callable[..., torch.Tensor] + ln: nn.Module + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): super().__init__() if isinstance(cfg, Dict): cfg = HookedTransformerConfig.from_dict(cfg) self.cfg = cfg + assert self.cfg.d_mlp is not None # TODO: should this not be optional? 
self.W_in = nn.Parameter( torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype) ) @@ -1214,7 +1223,7 @@ def forward( einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in) + self.b_in ) # [batch, pos, d_mlp] - if not self.cfg.act_fn.endswith("_ln"): + if self.cfg.act_fn is not None and not self.cfg.act_fn.endswith("_ln"): post_act = self.hook_post(self.act_fn(pre_act)) # [batch, pos, d_mlp] else: mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] @@ -1242,11 +1251,15 @@ class GatedMLP(nn.Module): In one equation, mlp_out = (Gelu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out """ + act_fn: Callable[..., torch.Tensor] + ln: nn.Module + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): super().__init__() if isinstance(cfg, Dict): cfg = HookedTransformerConfig.from_dict(cfg) self.cfg = cfg + assert self.cfg.d_mlp is not None # keep mypy happy self.W_in = nn.Parameter( torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype) ) @@ -1297,7 +1310,7 @@ def forward( "batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_gate ) ) # [batch, pos, d_mlp] - if not self.cfg.act_fn.endswith("_ln"): + if self.cfg.act_fn is not None and not self.cfg.act_fn.endswith("_ln"): pre_linear = self.hook_pre_linear( einsum( "batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in @@ -1321,6 +1334,10 @@ def forward( # Transformer Block class TransformerBlock(nn.Module): + ln1: nn.Module + ln2: nn.Module + mlp: nn.Module + def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index): super().__init__() if isinstance(cfg, Dict): diff --git a/transformer_lens/head_detector.py b/transformer_lens/head_detector.py index 41eb72da9..13964c358 100644 --- a/transformer_lens/head_detector.py +++ b/transformer_lens/head_detector.py @@ -2,6 +2,7 @@ Utilities for detecting specific types of heads (e.g. previous token heads). 
""" + import logging from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union, cast @@ -10,7 +11,8 @@ import torch from typing_extensions import Literal, get_args -from transformer_lens import ActivationCache, HookedTransformer +from transformer_lens.ActivationCache import ActivationCache +from transformer_lens.HookedTransformer import HookedTransformer from transformer_lens.utils import is_lower_triangular, is_square HeadName = Literal["previous_token_head", "duplicate_token_head", "induction_head"] diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py index 313c6d4d6..831ee11b1 100644 --- a/transformer_lens/hook_points.py +++ b/transformer_lens/hook_points.py @@ -2,11 +2,12 @@ Helpers to access activations in models. """ + import logging from contextlib import contextmanager from dataclasses import dataclass from functools import partial -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union, cast import torch.nn as nn import torch.utils.hooks as hooks @@ -71,12 +72,12 @@ def full_hook(module, module_input, module_output): hook.__repr__() ) # annotate the `full_hook` with the string representation of the `hook` function - handle = self.register_forward_hook(full_hook) - handle = LensHandle(handle, is_permanent, level) + pt_handle = self.register_forward_hook(full_hook) + handle = LensHandle(pt_handle, is_permanent, level) if prepend: # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this... 
- self._forward_hooks.move_to_end(handle.hook.id, last=False) + self._forward_hooks.move_to_end(handle.hook.id, last=False) # type: ignore # TODO: this type error could signify a bug self.fwd_hooks.insert(0, handle) else: @@ -92,12 +93,12 @@ def full_hook(module, module_input, module_output): hook.__repr__() ) # annotate the `full_hook` with the string representation of the `hook` function - handle = self.register_full_backward_hook(full_hook) - handle = LensHandle(handle, is_permanent, level) + pt_handle = self.register_full_backward_hook(full_hook) + handle = LensHandle(pt_handle, is_permanent, level) if prepend: # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this... - self._backward_hooks.move_to_end(handle.hook.id, last=False) + self._backward_hooks.move_to_end(handle.hook.id, last=False) # type: ignore # TODO: this type error could signify a bug self.bwd_hooks.insert(0, handle) else: self.bwd_hooks.append(handle) @@ -136,6 +137,7 @@ def layer(self): # Returns the layer index if the name has the form 'blocks.{layer}.{...}' # Helper function that's mainly useful on HookedTransformer # If it doesn't have this form, raises an error - + assert self.name is not None # keep mypy happy split_name = self.name.split(".") return int(split_name[1]) @@ -236,7 +238,13 @@ def check_and_add_hook( ) def check_hooks_to_add( - self, hook_point, hook_point_name, hook, dir="fwd", is_permanent=False + self, + hook_point, + hook_point_name, + hook, + dir="fwd", + is_permanent=False, + prepend=False, ) -> None: """Override this function to add checks on which hooks should be added""" pass @@ -300,7 +308,7 @@ def hooks( self.context_level += 1 for name, hook in fwd_hooks: - if type(name) == str: + if isinstance(name, str): self.mod_dict[name].add_hook( hook, dir="fwd", level=self.context_level ) @@ -310,13 +318,13 @@ def hooks( if name(hook_name): hp.add_hook(hook, dir="fwd", level=self.context_level) for name, hook in bwd_hooks: - if type(name) == 
str: + if isinstance(name, str): self.mod_dict[name].add_hook( hook, dir="bwd", level=self.context_level ) else: # Otherwise, name is a Boolean function on names - for hook_name, hp in self.hook_dict: + for hook_name, hp in self.hook_dict: # type: ignore if name(hook_name): hp.add_hook(hook, dir="bwd", level=self.context_level) yield self @@ -399,6 +407,9 @@ def add_caching_hooks( filter_list = names_filter names_filter = lambda name: name in filter_list + # mypy can't seem to infer this + names_filter = cast(Callable[[str], bool], names_filter) + self.is_caching = True def save_hook(tensor, hook): @@ -495,7 +506,7 @@ def get_caching_hooks( device=None, remove_batch_dim: bool = False, cache: Optional[dict] = None, - pos_slice: Slice = None, + pos_slice: Optional[Slice] = None, ) -> Tuple[dict, list, list]: """Creates hooks to cache activations. Note: It does not add the hooks to the model. @@ -516,14 +527,17 @@ def get_caching_hooks( if names_filter is None: names_filter = lambda name: True - elif type(names_filter) == str: + elif isinstance(names_filter, str): filter_str = names_filter names_filter = lambda name: name == filter_str - elif type(names_filter) == list: + elif isinstance(names_filter, list): filter_list = names_filter names_filter = lambda name: name in filter_list self.is_caching = True + # mypy can't seem to infer this + names_filter = cast(Callable[[str], bool], names_filter) + def save_hook(tensor, hook, is_backward=False): hook_name = hook.name if is_backward: @@ -549,6 +563,7 @@ def save_hook(tensor, hook, is_backward=False): if ( tensor.dim() >= -pos_dim ): # check if the residual stream has a pos dimension before trying to slice + assert pos_slice is not None # keep mypy happy resid_stream = pos_slice.apply(resid_stream, dim=pos_dim) cache[hook_name] = resid_stream diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 72affb15f..6d107eca3 100644 --- 
a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -2,10 +2,11 @@ This module contains functions for loading pretrained models from the Hugging Face Hub. """ + import dataclasses import logging import re -from typing import Dict, Optional +from typing import Dict, Optional, Union, cast import einops import torch @@ -943,9 +944,11 @@ def convert_hf_model_config(model_name: str, **kwargs): "eps": hf_config.rms_norm_eps, "d_vocab": hf_config.vocab_size, "act_fn": hf_config.hidden_act, - "n_key_value_heads": hf_config.num_key_value_heads - if hf_config.num_key_value_heads != hf_config.num_attention_heads - else None, + "n_key_value_heads": ( + hf_config.num_key_value_heads + if hf_config.num_key_value_heads != hf_config.num_attention_heads + else None + ), # This is done because the current implementation of GQA will use Grouped-Query Attention if # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as # the same as hf_config.num_attention_heads, in which case GQA should not be used. @@ -1123,7 +1126,7 @@ def get_pretrained_model_config( checkpoint_index: Optional[int] = None, checkpoint_value: Optional[int] = None, fold_ln: bool = False, - device: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, n_devices: int = 1, default_prepend_bos: bool = True, dtype: torch.dtype = torch.float32, @@ -1698,10 +1701,14 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): # the state dict keys for the K/V attention weight/biases, prepending "_" to the key names. 
using_gqa = cfg.n_key_value_heads is not None gqa_uscore = "_" if using_gqa else "" + # need a cast since MyPy isn't smart enough to realize that using_gqa implies n_key_value_heads is not None + n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads) # llama has no biases anywhere and deals with everything else roughly like # GPTNeoX with different names + assert cfg.d_mlp is not None # keep mypy happy + for l in range(cfg.n_layers): state_dict[f"blocks.{l}.ln1.w"] = llama.model.layers[l].input_layernorm.weight @@ -1709,12 +1716,8 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): W_K = llama.model.layers[l].self_attn.k_proj.weight W_V = llama.model.layers[l].self_attn.v_proj.weight W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) - W_K = einops.rearrange( - W_K, "(n h) m->n m h", n=cfg.n_key_value_heads if using_gqa else cfg.n_heads - ) - W_V = einops.rearrange( - W_V, "(n h) m->n m h", n=cfg.n_key_value_heads if using_gqa else cfg.n_heads - ) + W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads) + W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads) state_dict[f"blocks.{l}.attn.W_Q"] = W_Q state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V @@ -1723,13 +1726,13 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device ) state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros( - cfg.n_key_value_heads if using_gqa else cfg.n_heads, + n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device, ) state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros( - cfg.n_key_value_heads if using_gqa else cfg.n_heads, + n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device, @@ -1777,6 +1780,8 @@ def convert_qwen_weights(qwen, cfg: HookedTransformerConfig): model = qwen.transformer state_dict["embed.W_E"] = model.wte.weight + assert cfg.d_mlp is not None # keep mypy happy + for l 
in range(cfg.n_layers): state_dict[f"blocks.{l}.ln1.w"] = model.h[l].ln_1.weight @@ -1841,6 +1846,8 @@ def convert_qwen2_weights(qwen, cfg: HookedTransformerConfig): state_dict["embed.W_E"] = qwen.model.embed_tokens.weight + assert cfg.d_mlp is not None # keep mypy happy + for l in range(cfg.n_layers): state_dict[f"blocks.{l}.ln1.w"] = qwen.model.layers[l].input_layernorm.weight @@ -1914,6 +1921,9 @@ def convert_mistral_weights(mistral, cfg: HookedTransformerConfig): state_dict["embed.W_E"] = mistral.model.embed_tokens.weight + assert cfg.n_key_value_heads is not None # keep mypy happy + assert cfg.d_mlp is not None # keep mypy happy + # Mistral has no biases anywhere for l in range(cfg.n_layers): state_dict[f"blocks.{l}.ln1.w"] = mistral.model.layers[l].input_layernorm.weight @@ -2509,6 +2519,9 @@ def convert_phi_weights(phi, cfg: HookedTransformerConfig): def convert_gemma_weights(gemma, cfg: HookedTransformerConfig): state_dict = {} + assert cfg.n_key_value_heads is not None # mypy + assert cfg.d_mlp is not None # mypy + # Gemma Models scale embeddings by multiplying by sqrt(d_model) state_dict["embed.W_E"] = gemma.model.embed_tokens.weight * (cfg.d_model**0.5) diff --git a/transformer_lens/patching.py b/transformer_lens/patching.py index b97a95191..189cfda2e 100644 --- a/transformer_lens/patching.py +++ b/transformer_lens/patching.py @@ -51,7 +51,7 @@ import itertools from functools import partial -from typing import Callable, Optional, Sequence, Tuple, Union +from typing import Callable, Optional, Sequence, Tuple, Union, overload import einops import pandas as pd @@ -61,7 +61,8 @@ from typing_extensions import Literal import transformer_lens.utils as utils -from transformer_lens import ActivationCache, HookedTransformer +from transformer_lens.ActivationCache import ActivationCache +from transformer_lens.HookedTransformer import HookedTransformer # %% Logits = torch.Tensor @@ -92,6 +93,44 @@ def make_df_from_ranges( PatchedActivation = torch.Tensor +@overload 
+def generic_activation_patch( + model: HookedTransformer, + corrupted_tokens: Int[torch.Tensor, "batch pos"], + clean_cache: ActivationCache, + patching_metric: Callable[ + [Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""] + ], + patch_setter: Callable[ + [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation + ], + activation_name: str, + index_axis_names: Optional[Sequence[AxisNames]] = None, + index_df: Optional[pd.DataFrame] = None, + return_index_df: Literal[False] = False, +) -> torch.Tensor: + ... + + +@overload +def generic_activation_patch( + model: HookedTransformer, + corrupted_tokens: Int[torch.Tensor, "batch pos"], + clean_cache: ActivationCache, + patching_metric: Callable[ + [Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""] + ], + patch_setter: Callable[ + [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation + ], + activation_name: str, + index_axis_names: Optional[Sequence[AxisNames]], + index_df: Optional[pd.DataFrame], + return_index_df: Literal[True], +) -> Tuple[torch.Tensor, pd.DataFrame]: + ... + + def generic_activation_patch( model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, "batch pos"], @@ -643,7 +682,7 @@ def get_act_patch_attn_head_all_pos_every( Returns: patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, n_heads] """ - act_patch_results = [] + act_patch_results: list[torch.Tensor] = [] act_patch_results.append( get_act_patch_attn_head_out_all_pos( model, corrupted_tokens, clean_cache, metric diff --git a/transformer_lens/train.py b/transformer_lens/train.py index 379b34ecc..946295619 100644 --- a/transformer_lens/train.py +++ b/transformer_lens/train.py @@ -3,16 +3,19 @@ Utilities for training :class:`transformer_lens.HookedTransformer` models on autoregressive language modeling tasks. 
""" + from dataclasses import dataclass from typing import Optional import torch import torch.optim as optim import wandb +from torch.optim import Optimizer from torch.utils.data import DataLoader, Dataset from tqdm.auto import tqdm -from transformer_lens import HookedTransformer, utils +from transformer_lens import utils +from transformer_lens.HookedTransformer import HookedTransformer @dataclass @@ -81,6 +84,7 @@ def train( if config.device is None: config.device = utils.get_device() + optimizer: Optimizer if config.optimizer_name in ["Adam", "AdamW"]: # Weight decay in Adam is implemented badly, so use AdamW instead (see PyTorch AdamW docs) if config.weight_decay is not None: @@ -98,9 +102,9 @@ def train( optimizer = optim.SGD( model.parameters(), lr=config.lr, - weight_decay=config.weight_decay - if config.weight_decay is not None - else 0.0, + weight_decay=( + config.weight_decay if config.weight_decay is not None else 0.0 + ), momentum=config.momentum, ) else: diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index 5c74dac26..3a343a7e4 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -10,7 +10,7 @@ import re import shutil from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast import einops import numpy as np @@ -25,7 +25,7 @@ from rich import print as rprint from transformers import AutoTokenizer -from transformer_lens import FactoredMatrix +from transformer_lens.FactoredMatrix import FactoredMatrix CACHE_DIR = transformers.TRANSFORMERS_CACHE USE_DEFAULT_VALUE = None @@ -426,7 +426,7 @@ def sample_logits( # Type alias -SliceInput: Type = Optional[ +SliceInput = Optional[ Union[ int, Tuple[int,], @@ -473,6 +473,8 @@ class Slice: elif input_slice is a Tensor, same as list - Tensor is assumed to be a 1D list of indices. 
""" + slice: Union[int, slice, np.ndarray] + def __init__( self, input_slice: SliceInput = None, @@ -486,14 +488,13 @@ def __init__( Raises: ValueError: If the input_slice is not one of the above types. """ - if type(input_slice) == tuple: - input_slice: slice = slice(*input_slice) - self.slice = input_slice + if isinstance(input_slice, tuple): + self.slice = slice(*input_slice) self.mode = "slice" - elif type(input_slice) == int: + elif isinstance(input_slice, int): self.slice = input_slice self.mode = "int" - elif type(input_slice) == slice: + elif isinstance(input_slice, slice): self.slice = input_slice self.mode = "slice" elif type(input_slice) in [list, torch.Tensor, np.ndarray]: @@ -522,7 +523,7 @@ def apply( """ ndim = tensor.ndim slices = [slice(None)] * ndim - slices[dim] = self.slice + slices[dim] = self.slice # type: ignore return tensor[tuple(slices)] def indices( @@ -600,7 +601,7 @@ def get_act_name( return name match = re.match(r"([a-z]+)(\d+)([a-z]?.*)", name) if match is not None: - name, layer, layer_type = match.groups(0) + name, layer, layer_type = match.groups(0) # type: ignore layer_type_alias = { "a": "attn", @@ -672,10 +673,10 @@ def test_prompt( prompt: str, answer: str, model, # Can't give type hint due to circular imports - prepend_space_to_answer: Optional[bool] = True, - print_details: Optional[bool] = True, + prepend_space_to_answer: bool = True, + print_details: bool = True, prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, - top_k: Optional[int] = 10, + top_k: int = 10, ) -> None: """Test if the Model Can Give the Correct Answer to a Prompt. 
@@ -804,11 +805,11 @@ def composition_scores( left.rdim == right.ldim ), f"Composition scores require left.rdim==right.ldim, shapes were left: {left.shape}, right:{right.shape}" - right = right.collapse_r() - left = left.collapse_l() - r_norms = right.norm(dim=[-2, -1]) - l_norms = left.norm(dim=[-2, -1]) - comp_norms = (left @ right).norm(dim=[-2, -1]) + new_right = right.collapse_r() + new_left = left.collapse_l() + r_norms = new_right.norm(dim=[-2, -1]) + l_norms = new_left.norm(dim=[-2, -1]) + comp_norms = (new_left @ new_right).norm(dim=[-2, -1]) return comp_norms / r_norms / l_norms @@ -1096,7 +1097,7 @@ def __enter__(self): # Ensure the override is a valid value valid_values = info["valid_values"] assert ( - override in valid_values + override in valid_values # type: ignore ), f"{property} must be one of {valid_values}, but got {override}." # Fetch current default and store it to restore later From 760135a27c4b7873b0cb66aca541958f8939f60b Mon Sep 17 00:00:00 2001 From: Collin Date: Mon, 8 Apr 2024 14:38:04 -0700 Subject: [PATCH 52/73] Add Mixtral (#521) * add moe config options * bump transformers version, needed for hf mixtral * add architecture config * add moe component, no hooks yet * add convert_mixtral_weights * formatting * fix convert_mixtral_weights * fixes * rename moe state_dict names * add multi-gpu fixes by @coolvision * fix einsum * fix moe forward pass * cap mixtral context, model working * disable ln folding for moe (for now) * update htconfig docstring with moe options * formatting * add benchmarker to test_hooked_transformer * add moe gate and chosen expert hooks * formatting * add moe dtype warning * add special cases page to docs * formatting * fix missing .cfg * fix doc heading level, add desc. 
to moe hook points * fix formatting * fix new mypy errors * fix mypy issues for real this time * rename moe gate hook names --------- Co-authored-by: Bryce Meyer --- docs/source/content/special_cases.md | 11 ++ docs/source/index.md | 1 + tests/acceptance/test_hooked_transformer.py | 158 +++++++++++++++++++- transformer_lens/HookedTransformer.py | 26 +++- transformer_lens/HookedTransformerConfig.py | 24 ++- transformer_lens/components.py | 70 ++++++++- transformer_lens/loading_from_pretrained.py | 112 +++++++++++++- transformer_lens/past_key_value_caching.py | 1 + 8 files changed, 394 insertions(+), 9 deletions(-) create mode 100644 docs/source/content/special_cases.md diff --git a/docs/source/content/special_cases.md b/docs/source/content/special_cases.md new file mode 100644 index 000000000..a5eae2164 --- /dev/null +++ b/docs/source/content/special_cases.md @@ -0,0 +1,11 @@ +# Special Cases + +## Mixture of Experts error rates +Due to the Top-K gating performed in the hidden layer of Mixture of Experts models, small errors can be amplified +greatly in cases where a different expert is selected, which leads to a higher than normal variance in the error rate +of the final logits. In testing done on Mixtral running in half precision, the standard deviation of the absolute error +rate of the logits compared to those from the default model was found to be around 2e-3. + +There are two main ways to mitigate this: +1. Disable preprocessing options by using `HookedTransformer.from_pretrained_no_processing` instead of `HookedTransformer.from_pretrained` +2. 
Increase the precision of the data type used in the model diff --git a/docs/source/index.md b/docs/source/index.md index 09b00f20f..869ebc248 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -44,6 +44,7 @@ content/citation content/contributing generated/demos/Main_Demo generated/demos/Exploratory_Analysis_Demo +content/special_cases ``` ```{toctree} diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index b267a3278..180717da8 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -1,13 +1,18 @@ import gc import os +import pandas as pd import pytest import torch -from transformers import AutoConfig, AutoModelForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from transformer_lens import HookedTransformer from transformer_lens.components import LayerNormPre -from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES +from transformer_lens.HookedTransformer import DTYPE_FROM_STRING +from transformer_lens.loading_from_pretrained import ( + OFFICIAL_MODEL_NAMES, + get_official_model_name, +) from transformer_lens.utils import clear_huggingface_cache TINY_STORIES_MODEL_NAMES = [ @@ -245,6 +250,155 @@ def check_norm_folding( ) +def calculate_error(logits1, logits2): + t1 = torch.softmax(logits1, dim=-1).to("cpu") + t2 = torch.softmax(logits2, dim=-1).to("cpu") + err = torch.abs(t1 - t2) + return { + "max": torch.max(err).item(), + "mean": torch.mean(err).item(), + "median": torch.median(err).item(), + "std": torch.std(err).item(), + } + + +def benchmark_model_options( + model_name: str, + hf_model=None, + tokenizer=None, + device="cuda", + n_devices=1, + dtype=torch.float16, + cache_in_cpu=True, +): + options = { + "fold_ln": False, + "center_writing_weights": False, + "center_unembed": False, + "fold_value_biases": False, + } + + prompts = [ + "Hello, world!", + "This is a test.", + "What is it 
about?", + "I don't know.", + ] + + model_name = get_official_model_name(model_name) + + if hf_model is None: + hf_model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=dtype, device_map="auto" + ) + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(model_name) + + tokens = tokenizer( + prompts, return_tensors="pt", truncation=True, max_length=4 + ).input_ids.to(device) + + # hf_model = hf_model.to(device) + hf_logits = hf_model(tokens).logits.detach() + hf_logits = hf_logits.to("cpu") + + if cache_in_cpu: + hf_model = hf_model.to("cpu") + else: + del hf_model + hf_model = None + + torch.cuda.empty_cache() + gc.collect() + + results = {} + + # Check the error when all processing options are disabled + tl_model = HookedTransformer.from_pretrained( + model_name, + hf_model=hf_model, + tokenizer=tokenizer, + device=device, + n_devices=n_devices, + dtype=dtype, + **options, + ) + tl_logits = tl_model(tokens).detach().to("cpu") + results["no_options"] = calculate_error(hf_logits, tl_logits) + del tl_model, tl_logits + torch.cuda.empty_cache() + + # Check the error when each processing option is enabled individually + for option in options: + gc.collect() + new_options = options.copy() + new_options[option] = True + tl_model = HookedTransformer.from_pretrained( + model_name, + hf_model=hf_model, + tokenizer=tokenizer, + device=device, + n_devices=n_devices, + dtype=dtype, + **new_options, + ) + tl_logits = tl_model(tokens).detach().to("cpu") + results[option] = calculate_error(hf_logits, tl_logits) + + del tl_model, tl_logits + torch.cuda.empty_cache() + gc.collect() + + # Check the error when all processing options are enabled + all_options = {k: True for k, v in options.items()} + tl_model = HookedTransformer.from_pretrained( + model_name, + hf_model=hf_model, + tokenizer=tokenizer, + device=device, + n_devices=n_devices, + dtype=dtype, + **all_options, + ) + tl_logits = tl_model(tokens).detach().to("cpu") + results["all_options"] = 
calculate_error(hf_logits, tl_logits) + + del tl_model, tl_logits + + del hf_model + del tokens + gc.collect() + torch.cuda.empty_cache() + + return results + + +def benchmark_models(models, device="cuda", n_devices=1, cache_in_cpu=True): + """ + Benchmark the error introduced by different options and data types for a list of models. + :param models: A dict mapping model names to a list of dtypes to test + """ + rows = [] + + for model in models: + dtypes = models[model] + for dtype in dtypes: + print(f"Testing {model} with dtype {dtype}") + results = benchmark_model_options( + model, + device=device, + n_devices=n_devices, + dtype=DTYPE_FROM_STRING[dtype], + cache_in_cpu=cache_in_cpu, + ) + for option, result in results.items(): + rows.append( + {"model": model, "dtype": dtype, "options": option, **result} + ) + + return pd.DataFrame(rows) + + def check_similarity_with_hf_model(tl_model, hf_model, prompt="Hello, world!"): """ Check that the TransformerLens model and the HuggingFace model diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index a06e6178b..d0655c835 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1540,9 +1540,22 @@ def load_and_process_state_dict( "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`." ) + if ( + self.cfg.dtype not in [torch.float32, torch.float64] + and self.cfg.num_experts + and self.cfg.num_experts > 1 + ): + logging.warning( + "When running MoE models, it is advised to use a higher precision data type. See docs for more info." + ) + state_dict = self.fill_missing_keys(state_dict) if fold_ln: - if self.cfg.normalization_type in ["LN", "LNPre"]: + if self.cfg.num_experts and self.cfg.num_experts > 1: + logging.warning( + "You are using MoE, so the layer norm weights can't be folded! 
Skipping" + ) + elif self.cfg.normalization_type in ["LN", "LNPre"]: state_dict = self.fold_layer_norm(state_dict) elif self.cfg.normalization_type in ["RMS", "RMSPre"]: state_dict = self.fold_layer_norm( @@ -1967,7 +1980,11 @@ def process_weights_( version of the same model. """ state_dict = self.state_dict() - if fold_ln and self.cfg.normalization_type == "LN": + if fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1: + # If we're using MoE, we don't fold the layer norm weights, so we don't need to do any preprocessing + # A warning is already issued in `load_and_process_state_dict` + pass + elif fold_ln and self.cfg.normalization_type == "LN": # If we're folding the LN into the weights, we need to replace all the layernorm layers # with LayerNormPres, which do not have learnable parameters. This is somewhat hacky, # but it's the easiest way to do it. @@ -2183,7 +2200,10 @@ def generate( # instead. sampled_tokens[finished_sequences] = eos_token_for_padding finished_sequences.logical_or_( - torch.isin(sampled_tokens, torch.tensor(stop_tokens).to(device)) + torch.isin( + sampled_tokens.to(self.cfg.device), + torch.tensor(stop_tokens).to(self.cfg.device), + ) ) tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1) diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py index bb284a8b6..2ea815f0d 100644 --- a/transformer_lens/HookedTransformerConfig.py +++ b/transformer_lens/HookedTransformerConfig.py @@ -153,6 +153,10 @@ class HookedTransformerConfig: Only for models that use Grouped Query Attention. post_embedding_ln (bool): Whether to apply layer normalization after embedding the tokens. Defaults to False. + num_experts (int, *optional*): The number of experts to use in the MoE layer. If set, experts_per_token + must also be set. Set to None if not using MoE. + experts_per_token (int, *optional*): The number of experts to use for each pass in the MoE layer. 
If set, + num_experts must also be set. Set to None if not using MoE. """ n_layers: int @@ -205,6 +209,8 @@ class HookedTransformerConfig: rotary_base: int = 10000 trust_remote_code: bool = False rotary_adjacent_pairs: bool = False + num_experts: Optional[int] = None + experts_per_token: Optional[int] = None def __post_init__(self): if self.n_heads == -1: @@ -255,6 +261,15 @@ def __post_init__(self): if self.positional_embedding_type == "rotary" and self.rotary_dim is None: self.rotary_dim = self.d_head + if self.num_experts is not None: + assert ( + self.experts_per_token is not None + ), "experts_per_token must be set if num_experts is set" + if self.experts_per_token is not None: + assert ( + self.num_experts is not None + ), "num_experts must be set if experts_per_token is set" + # The number of parameters in attention layers (ignoring biases and layer norm). 4 because W_Q, W_K, W_V and W_O self.n_params = self.n_layers * ( (self.d_model * self.d_head * self.n_heads * 4) @@ -262,7 +277,14 @@ def __post_init__(self): if not self.attn_only: assert self.d_mlp is not None # mypy # Number of parameters in MLP layers (ignoring biases and layer norm). 
2 because W_in and W_out - self.n_params += self.n_layers * self.d_model * self.d_mlp * 2 + mlp_params_per_layer = self.d_model * self.d_mlp * (2 + self.gated_mlp) + + if self.num_experts: + # If we are using MoE, we multiply by num_experts, and add the expert gate parameters (d_model * num_experts) + mlp_params_per_layer = ( + mlp_params_per_layer + self.d_model + ) * self.num_experts + self.n_params += self.n_layers * mlp_params_per_layer if self.device is None: self.device = utils.get_device() diff --git a/transformer_lens/components.py b/transformer_lens/components.py index 452a31eea..dc72da6b1 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -591,6 +591,7 @@ def forward( pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] pattern = pattern.to(self.cfg.dtype) + pattern = pattern.to(v.device) z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] if not self.cfg.use_attn_result: out = ( @@ -736,8 +737,10 @@ def apply_causal_mask( if attention_mask is not None: # Apply a causal mask to the attention scores considering the padding einsum_str = "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos" + final_mask = final_mask.to(attention_mask.device) final_mask = einops.einsum(final_mask, attention_mask, einsum_str).bool() + attn_scores = attn_scores.to(final_mask.device) return torch.where(final_mask, attn_scores, self.IGNORE) def calculate_sin_cos_rotary( @@ -814,6 +817,7 @@ def apply_rotary( offset_position_ids = get_offset_position_ids( past_kv_pos_offset, attention_mask ) + offset_position_ids = offset_position_ids.to(self.rotary_cos.device) mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :] mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :] x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin @@ -1332,6 +1336,68 @@ def forward( ) +class 
MoE(nn.Module): + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + + # Ensure that num_experts and experts_per_token are specified and non-zero + assert ( + cfg.num_experts is not None + ), "num_experts must be specified for MoE layer" + assert ( + cfg.experts_per_token + ), "experts_per_token must be specified for MoE layer" + self.experts_per_token: int = cfg.experts_per_token + assert ( + cfg.experts_per_token <= cfg.num_experts + ), "experts_per_token must be less than or equal to num_experts" + + self.experts = nn.ModuleList( + [ + GatedMLP(cfg) if cfg.gated_mlp else MLP(cfg) + for _ in range(cfg.num_experts) + ] + ) + self.W_gate = nn.Parameter( + torch.empty(cfg.d_model, cfg.num_experts, dtype=cfg.dtype) + ) + + # Hook on the weights of selected experts [batch pos experts_per_token] + self.hook_expert_weights = HookPoint() + # Hook on the indices of selected experts [batch pos experts_per_token] + self.hook_expert_indices = HookPoint() + + def forward( + self, x: Float[torch.Tensor, "batch pos d_model"] + ) -> Float[torch.Tensor, "batch pos d_model"]: + # [batch, pos, d_model] -> [batch, pos, num_experts] + gate_logits = einsum( + "batch pos d_model, d_model num_experts -> batch pos num_experts", + x, + self.W_gate, + ) + + # choose the top k(=experts_per_token) experts to use + # both are [batch, pos, experts_per_token] + weights, expert_indices = torch.topk(gate_logits, self.experts_per_token) + weights = self.hook_expert_weights(F.softmax(weights, dim=-1)) + expert_indices = self.hook_expert_indices(expert_indices) + + results = torch.zeros_like(x) + for i, expert_mlp in enumerate(self.experts): + # find the batch, pos, and expert indices which use this expert + batch, pos, expert = torch.where(expert_indices == i) + # accumulate the weighted outputs from the expert + results[batch] += weights[batch, pos, expert, None, None] * 
expert_mlp( + x[batch] + ) + + return results + + # Transformer Block class TransformerBlock(nn.Module): ln1: nn.Module @@ -1379,7 +1445,9 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index): attn_type = self.cfg.attn_types[block_index] self.attn = attention(cfg, attn_type, block_index) if not self.cfg.attn_only: - if self.cfg.gated_mlp: + if self.cfg.num_experts: + self.mlp = MoE(cfg) + elif self.cfg.gated_mlp: self.mlp = GatedMLP(cfg) else: self.mlp = MLP(cfg) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 6d107eca3..e417dd79c 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -143,6 +143,8 @@ "stabilityai/stablelm-tuned-alpha-7b", "mistralai/Mistral-7B-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", + "mistralai/Mixtral-8x7B-v0.1", + "mistralai/Mixtral-8x7B-Instruct-v0.1", "bigscience/bloom-560m", "bigscience/bloom-1b1", "bigscience/bloom-1b7", @@ -551,6 +553,11 @@ ], "mistralai/Mistral-7B-v0.1": ["mistral-7b"], "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"], + "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"], + "mistralai/Mixtral-8x7B-Instruct-v0.1": [ + "mixtral-instruct", + "mixtral-8x7b-instruct", + ], "bigscience/bloom-560m": ["bloom-560m"], "bigscience/bloom-1b1": ["bloom-1b1"], "bigscience/bloom-1b7": ["bloom-1b7"], @@ -650,8 +657,6 @@ def convert_hf_model_config(model_name: str, **kwargs): # Load HuggingFace model config if "llama" in official_model_name.lower(): architecture = "LlamaForCausalLM" - elif "mistral" in official_model_name.lower(): - architecture = "MistralForCausalLM" elif "gemma" in official_model_name.lower(): architecture = "GemmaForCausalLM" else: @@ -899,6 +904,28 @@ def convert_hf_model_config(model_name: str, **kwargs): "use_local_attn": True, "rotary_dim": 4096 // 32, } + elif architecture == "MixtralForCausalLM": + cfg_dict = { + "d_model": hf_config.hidden_size, + 
"d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": 2048, # hf_config.max_position_embeddings, # Capped due to memory issues + "d_vocab": hf_config.vocab_size, + "act_fn": hf_config.hidden_act, + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "window_size": hf_config.sliding_window, # This is None, as no sliding window was used + "attn_types": ["global"] * 32, + "eps": hf_config.rms_norm_eps, + "n_key_value_heads": hf_config.num_key_value_heads, + "gated_mlp": True, + "use_local_attn": False, + "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, + "num_experts": hf_config.num_local_experts, + "experts_per_token": hf_config.num_experts_per_tok, + } elif architecture == "BloomForCausalLM": cfg_dict = { "d_model": hf_config.hidden_size, @@ -1411,6 +1438,8 @@ def get_pretrained_state_dict( state_dict = convert_bert_weights(hf_model, cfg) elif cfg.original_architecture == "MistralForCausalLM": state_dict = convert_mistral_weights(hf_model, cfg) + elif cfg.original_architecture == "MixtralForCausalLM": + state_dict = convert_mixtral_weights(hf_model, cfg) elif cfg.original_architecture == "BloomForCausalLM": state_dict = convert_bloom_weights(hf_model, cfg) elif cfg.original_architecture == "GPT2LMHeadCustomModel": @@ -1979,6 +2008,85 @@ def convert_mistral_weights(mistral, cfg: HookedTransformerConfig): return state_dict +def convert_mixtral_weights(mixtral, cfg: HookedTransformerConfig): + # The same as Mistral, but with the MLP replaced with MoE + # As with Mistral, Mixtral has no biases + + state_dict = {} + + assert cfg.n_key_value_heads is not None # keep mypy happy + assert cfg.d_mlp is not None + assert cfg.num_experts is not None + + state_dict["embed.W_E"] = mixtral.model.embed_tokens.weight + + for l in range(cfg.n_layers): + state_dict[f"blocks.{l}.ln1.w"] = 
mixtral.model.layers[l].input_layernorm.weight + + W_Q = mixtral.model.layers[l].self_attn.q_proj.weight + W_K = mixtral.model.layers[l].self_attn.k_proj.weight + W_V = mixtral.model.layers[l].self_attn.v_proj.weight + W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) + W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads) + W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads) + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q + state_dict[f"blocks.{l}.attn._W_K"] = W_K + state_dict[f"blocks.{l}.attn._W_V"] = W_V + + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( + cfg.n_heads, cfg.d_head, dtype=cfg.dtype + ) + state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( + cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype + ) + state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( + cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype + ) + + W_O = mixtral.model.layers[l].self_attn.o_proj.weight + W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) + state_dict[f"blocks.{l}.attn.W_O"] = W_O + + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.ln2.w"] = mixtral.model.layers[ + l + ].post_attention_layernorm.weight + + state_dict[f"blocks.{l}.mlp.W_gate"] = mixtral.model.layers[ + l + ].block_sparse_moe.gate.weight.T + + # The mapping here from wn to W_{in/out/gate} is a bit confusing: + # w1 -> W_gate + # w2 -> W_out + # w3 -> W_in + # See https://github.com/mistralai/mistral-src/blob/main/mistral/model.py#L128 for reference + for e in range(cfg.num_experts): + state_dict[f"blocks.{l}.mlp.experts.{e}.W_in"] = ( + mixtral.model.layers[l].block_sparse_moe.experts[e].w3.weight.T + ) + state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate"] = ( + mixtral.model.layers[l].block_sparse_moe.experts[e].w1.weight.T + ) + state_dict[f"blocks.{l}.mlp.experts.{e}.b_in"] = torch.zeros( + cfg.d_mlp, dtype=cfg.dtype + ) + state_dict[f"blocks.{l}.mlp.experts.{e}.W_out"] = ( + 
mixtral.model.layers[l].block_sparse_moe.experts[e].w2.weight.T + ) + state_dict[f"blocks.{l}.mlp.experts.{e}.b_out"] = torch.zeros( + cfg.d_model, dtype=cfg.dtype + ) + + state_dict["ln_final.w"] = mixtral.model.norm.weight.data + + state_dict["unembed.W_U"] = mixtral.lm_head.weight.T + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) + + return state_dict + + def convert_opt_weights(opt, cfg: HookedTransformerConfig): state_dict = {} diff --git a/transformer_lens/past_key_value_caching.py b/transformer_lens/past_key_value_caching.py index cff973191..0aca31278 100644 --- a/transformer_lens/past_key_value_caching.py +++ b/transformer_lens/past_key_value_caching.py @@ -109,6 +109,7 @@ def unfreeze(self): def append_attention_mask( self, attention_mask: Int[torch.Tensor, "batch new_tokens"] ): + attention_mask = attention_mask.to(self.previous_attention_mask.device) updated_attention_mask = torch.cat( [self.previous_attention_mask, attention_mask], dim=-1 ) From e1980061aaaf55640de9f46fa587c9fe16e387f3 Mon Sep 17 00:00:00 2001 From: Lawrence Chan Date: Thu, 11 Apr 2024 14:01:24 -0700 Subject: [PATCH 53/73] Standardize black line length to 100, in line with other project settings (#538) * Update black line length to 100 * run black with -l 100 * edit contributing.md to include new line length * add black -l 100 to .vscode for convenience * fixed merge saving error * fixed merge issue in params * ran format * ran format on tests --------- Co-authored-by: Bryce Meyer --- .vscode/cspell.json | 1 + .vscode/settings.json | 4 + docs/make_docs.py | 5 +- docs/source/content/contributing.md | 2 + pyproject.toml | 1 + tests/acceptance/test_activation_cache.py | 60 ++--- tests/acceptance/test_hook_tokens.py | 8 +- tests/acceptance/test_hooked_encoder.py | 20 +- tests/acceptance/test_hooked_transformer.py | 27 +- tests/acceptance/test_multi_gpu.py | 29 +-- .../test_tokenizer_special_tokens.py | 4 +- .../manual_checks_type_annotations.py | 4 +- 
.../test_multiply_by_matrix.py | 4 +- tests/unit/factored_matrix/test_properties.py | 16 +- tests/unit/test_attention_mask.py | 4 +- tests/unit/test_cache_pos_slice.py | 24 +- tests/unit/test_create_hooked_encoder.py | 4 +- tests/unit/test_grouped_query_attention.py | 4 +- tests/unit/test_head_detector.py | 40 +-- tests/unit/test_hooks.py | 10 +- tests/unit/test_kv_cache.py | 32 +-- tests/unit/test_left_padding.py | 12 +- tests/unit/test_only_tokenizer.py | 28 +- tests/unit/test_prepend_bos.py | 44 +--- tests/unit/test_start_at_layer.py | 10 +- tests/unit/test_stop_at_layer.py | 4 +- tests/unit/test_svd_interpreter.py | 12 +- tests/unit/test_tokenization_methods.py | 8 +- tests/unit/test_utils.py | 28 +- transformer_lens/ActivationCache.py | 68 ++--- transformer_lens/FactoredMatrix.py | 4 +- transformer_lens/HookedEncoder.py | 74 ++---- transformer_lens/HookedTransformer.py | 228 +++++------------ transformer_lens/HookedTransformerConfig.py | 33 +-- transformer_lens/SVDInterpreter.py | 8 +- transformer_lens/components.py | 200 ++++----------- transformer_lens/evals.py | 40 +-- transformer_lens/head_detector.py | 20 +- transformer_lens/hook_points.py | 32 +-- transformer_lens/loading_from_pretrained.py | 239 +++++------------- transformer_lens/past_key_value_caching.py | 12 +- transformer_lens/patching.py | 54 +--- transformer_lens/train.py | 8 +- transformer_lens/utilities/devices.py | 4 +- transformer_lens/utils.py | 89 ++----- 45 files changed, 407 insertions(+), 1155 deletions(-) diff --git a/.vscode/cspell.json b/.vscode/cspell.json index 19eedf858..eea97769b 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -4,6 +4,7 @@ "accum", "adrià", "aengus", + "allclose", "alonso", "arange", "argmax", diff --git a/.vscode/settings.json b/.vscode/settings.json index 2fa400667..63e6e310a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -9,6 +9,7 @@ "source.organizeImports": "explicit" }, "editor.formatOnSave": true, + "editor.rulers": [100], 
"evenBetterToml.formatter.allowedBlankLines": 1, "evenBetterToml.formatter.arrayAutoCollapse": true, "evenBetterToml.formatter.arrayAutoExpand": true, @@ -40,4 +41,7 @@ "rewrap.wrappingColumn": 100, "mypy.runUsingActiveInterpreter": true, "editor.defaultFormatter": "ms-python.black-formatter", + "black-formatter.args": [ + "-l 100" + ], } \ No newline at end of file diff --git a/docs/make_docs.py b/docs/make_docs.py index ea6c41fba..d15b7a124 100644 --- a/docs/make_docs.py +++ b/docs/make_docs.py @@ -93,10 +93,7 @@ def generate_model_table(_app: Optional[Any] = None): ] df = pd.DataFrame( { - name: [ - get_property(name, model_name) - for model_name in loading.DEFAULT_MODEL_ALIASES - ] + name: [get_property(name, model_name) for model_name in loading.DEFAULT_MODEL_ALIASES] for name in column_names }, index=loading.DEFAULT_MODEL_ALIASES, diff --git a/docs/source/content/contributing.md b/docs/source/content/contributing.md index 544353aa1..49bf28f99 100644 --- a/docs/source/content/contributing.md +++ b/docs/source/content/contributing.md @@ -43,6 +43,8 @@ actions. - Format all files via `make format` - Only check the formatting via `make check-format` +Note that `black` line length is set to 100 in `pyproject.toml` (instead of the default 88). + ## Documentation Please make sure to add thorough documentation for any features you add. 
You should do this directly diff --git a/pyproject.toml b/pyproject.toml index 815641f29..62d8ab2d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ ignore_missing_imports=true [tool.black] + line-length=100 # Set line length to 100 to match other tools # Exclude snapshot tests & .venv exclude=''' ( diff --git a/tests/acceptance/test_activation_cache.py b/tests/acceptance/test_activation_cache.py index 452d5e37c..aa52cf6b7 100644 --- a/tests/acceptance/test_activation_cache.py +++ b/tests/acceptance/test_activation_cache.py @@ -64,18 +64,12 @@ def test_logit_attrs_matches_reference_code(): _, cache = model.run_with_cache(tokens) # Get accumulated resid - accumulated_residual = cache.accumulated_resid( - layer=-1, incl_mid=True, pos_slice=-1 - ) + accumulated_residual = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1) # Get ref ave logit diffs (cribbed notebook code) answer_residual_directions = model.tokens_to_residual_directions(answer_tokens) - logit_diff_directions = ( - answer_residual_directions[:, 0] - answer_residual_directions[:, 1] - ) - scaled_residual_stack = cache.apply_ln_to_stack( - accumulated_residual, layer=-1, pos_slice=-1 - ) + logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1] + scaled_residual_stack = cache.apply_ln_to_stack(accumulated_residual, layer=-1, pos_slice=-1) ref_ave_logit_diffs = einsum( "... 
batch d_model, batch d_model -> ...", scaled_residual_stack, @@ -111,12 +105,8 @@ def test_logit_attrs_works_for_all_input_shapes(): # Get ref logit diffs (cribbed notebook code) answer_residual_directions = model.tokens_to_residual_directions(answer_tokens) - logit_diff_directions = ( - answer_residual_directions[:, 0] - answer_residual_directions[:, 1] - ) - scaled_residual_stack = cache.apply_ln_to_stack( - accumulated_residual, layer=-1, pos_slice=-1 - ) + logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1] + scaled_residual_stack = cache.apply_ln_to_stack(accumulated_residual, layer=-1, pos_slice=-1) ref_logit_diffs = einsum( "... d_model, ... d_model -> ...", scaled_residual_stack, logit_diff_directions ) @@ -198,9 +188,7 @@ def test_accumulated_resid_with_apply_ln(): _, cache = model.run_with_cache(tokens) # Get accumulated resid and apply ln seperately (cribbed notebook code) - accumulated_residual = cache.accumulated_resid( - layer=-1, incl_mid=True, pos_slice=-1 - ) + accumulated_residual = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1) ref_scaled_residual_stack = cache.apply_ln_to_stack( accumulated_residual, layer=-1, pos_slice=-1 ) @@ -210,9 +198,7 @@ def test_accumulated_resid_with_apply_ln(): layer=-1, incl_mid=True, pos_slice=-1, apply_ln=True ) - assert torch.isclose( - ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7 - ).all() + assert torch.isclose(ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7).all() @torch.set_grad_enabled(False) @@ -227,16 +213,12 @@ def test_decompose_resid_with_apply_ln(): # Get decomposed resid and apply ln seperately (cribbed notebook code) per_layer_residual = cache.decompose_resid(layer=-1, pos_slice=-1) - ref_scaled_residual_stack = cache.apply_ln_to_stack( - per_layer_residual, layer=-1, pos_slice=-1 - ) + ref_scaled_residual_stack = cache.apply_ln_to_stack(per_layer_residual, layer=-1, pos_slice=-1) # Get scaled_residual_stack using 
apply_ln parameter scaled_residual_stack = cache.decompose_resid(layer=-1, pos_slice=-1, apply_ln=True) - assert torch.isclose( - ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7 - ).all() + assert torch.isclose(ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7).all() @torch.set_grad_enabled(False) @@ -251,18 +233,12 @@ def test_stack_head_results_with_apply_ln(): # Get per head resid stack and apply ln seperately (cribbed notebook code) per_head_residual = cache.stack_head_results(layer=-1, pos_slice=-1) - ref_scaled_residual_stack = cache.apply_ln_to_stack( - per_head_residual, layer=-1, pos_slice=-1 - ) + ref_scaled_residual_stack = cache.apply_ln_to_stack(per_head_residual, layer=-1, pos_slice=-1) # Get scaled_residual_stack using apply_ln parameter - scaled_residual_stack = cache.stack_head_results( - layer=-1, pos_slice=-1, apply_ln=True - ) + scaled_residual_stack = cache.stack_head_results(layer=-1, pos_slice=-1, apply_ln=True) - assert torch.isclose( - ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7 - ).all() + assert torch.isclose(ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7).all() @torch.set_grad_enabled(False) @@ -277,15 +253,9 @@ def test_stack_neuron_results_with_apply_ln(): # Get neuron result stack and apply ln seperately neuron_result_stack = cache.stack_neuron_results(layer=-1, pos_slice=-1) - ref_scaled_residual_stack = cache.apply_ln_to_stack( - neuron_result_stack, layer=-1, pos_slice=-1 - ) + ref_scaled_residual_stack = cache.apply_ln_to_stack(neuron_result_stack, layer=-1, pos_slice=-1) # Get scaled_residual_stack using apply_ln parameter - scaled_residual_stack = cache.stack_neuron_results( - layer=-1, pos_slice=-1, apply_ln=True - ) + scaled_residual_stack = cache.stack_neuron_results(layer=-1, pos_slice=-1, apply_ln=True) - assert torch.isclose( - ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7 - ).all() + assert torch.isclose(ref_scaled_residual_stack, scaled_residual_stack, 
atol=1e-7).all() diff --git a/tests/acceptance/test_hook_tokens.py b/tests/acceptance/test_hook_tokens.py index 2280daef1..73690ae1a 100644 --- a/tests/acceptance/test_hook_tokens.py +++ b/tests/acceptance/test_hook_tokens.py @@ -30,9 +30,7 @@ def test_patch_tokens(): new_first_token = model.to_single_token("Hi") # Define hook function to alter the first token - def hook_fn( - tokens: Int[t.Tensor, "batch seq"], hook: HookPoint, new_first_token: int - ): + def hook_fn(tokens: Int[t.Tensor, "batch seq"], hook: HookPoint, new_first_token: int): assert ( tokens[0, 0].item() != new_first_token ) # Need new_first_token to be different from original @@ -43,9 +41,7 @@ def hook_fn( out_from_hook = model.run_with_hooks( prompt, prepend_bos=False, - fwd_hooks=[ - ("hook_tokens", functools.partial(hook_fn, new_first_token=new_first_token)) - ], + fwd_hooks=[("hook_tokens", functools.partial(hook_fn, new_first_token=new_first_token))], ) out_direct = model(modified_prompt, prepend_bos=False) diff --git a/tests/acceptance/test_hooked_encoder.py b/tests/acceptance/test_hooked_encoder.py index 0859686ef..e8394d873 100644 --- a/tests/acceptance/test_hooked_encoder.py +++ b/tests/acceptance/test_hooked_encoder.py @@ -41,9 +41,7 @@ def test_full_model(our_bert, huggingface_bert, tokenizer): input_ids = tokenized["input_ids"] attention_mask = tokenized["attention_mask"] - huggingface_bert_out = huggingface_bert( - input_ids, attention_mask=attention_mask - ).logits + huggingface_bert_out = huggingface_bert(input_ids, attention_mask=attention_mask).logits our_bert_out = our_bert(input_ids, one_zero_attention_mask=attention_mask) assert_close(huggingface_bert_out, our_bert_out, rtol=1.3e-6, atol=4e-5) @@ -97,23 +95,17 @@ def test_bert_block(our_bert, huggingface_bert, hello_world_tokens): def test_mlm_head(our_bert, huggingface_bert, hello_world_tokens): - huggingface_bert_core_outputs = huggingface_bert.bert( - hello_world_tokens - ).last_hidden_state + huggingface_bert_core_outputs = 
huggingface_bert.bert(hello_world_tokens).last_hidden_state our_mlm_head_out = our_bert.mlm_head(huggingface_bert_core_outputs) our_unembed_out = our_bert.unembed(our_mlm_head_out) - huggingface_predictions_out = huggingface_bert.cls.predictions( - huggingface_bert_core_outputs - ) + huggingface_predictions_out = huggingface_bert.cls.predictions(huggingface_bert_core_outputs) assert_close(our_unembed_out, huggingface_predictions_out, rtol=1.3e-6, atol=4e-5) def test_unembed(our_bert, huggingface_bert, hello_world_tokens): - huggingface_bert_core_outputs = huggingface_bert.bert( - hello_world_tokens - ).last_hidden_state + huggingface_bert_core_outputs = huggingface_bert.bert(hello_world_tokens).last_hidden_state our_mlm_head_out = our_bert.mlm_head(huggingface_bert_core_outputs) huggingface_predictions_out = huggingface_bert.cls.predictions.transform( @@ -167,9 +159,7 @@ def test_half_precision(dtype): def test_predictions(our_bert, huggingface_bert, tokenizer): input_ids = tokenizer("The [MASK] sat on the mat", return_tensors="pt")["input_ids"] - def get_predictions( - logits: Float[torch.Tensor, "batch pos d_vocab"], positions: List[int] - ): + def get_predictions(logits: Float[torch.Tensor, "batch pos d_vocab"], positions: List[int]): logits_at_position = logits.squeeze(0)[positions] predicted_tokens = F.softmax(logits_at_position, dim=-1).argmax(dim=-1) return tokenizer.batch_decode(predicted_tokens) diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index 180717da8..9d9e2bb19 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -19,9 +19,7 @@ name for name in OFFICIAL_MODEL_NAMES if name.startswith("roneneldan/TinyStories") ] -PYTHIA_MODEL_NAMES = [ - name for name in OFFICIAL_MODEL_NAMES if name.startswith("EleutherAI/pythia") -] +PYTHIA_MODEL_NAMES = [name for name in OFFICIAL_MODEL_NAMES if name.startswith("EleutherAI/pythia")] model_names = [ 
"attn-only-demo", @@ -243,10 +241,7 @@ def check_norm_folding( ) return torch.max( - torch.abs( - torch.softmax(folded_logits, dim=-1) - - torch.softmax(unfolded_logits, dim=-1) - ) + torch.abs(torch.softmax(folded_logits, dim=-1) - torch.softmax(unfolded_logits, dim=-1)) ) @@ -294,9 +289,9 @@ def benchmark_model_options( if tokenizer is None: tokenizer = AutoTokenizer.from_pretrained(model_name) - tokens = tokenizer( - prompts, return_tensors="pt", truncation=True, max_length=4 - ).input_ids.to(device) + tokens = tokenizer(prompts, return_tensors="pt", truncation=True, max_length=4).input_ids.to( + device + ) # hf_model = hf_model.to(device) hf_logits = hf_model(tokens).logits.detach() @@ -392,9 +387,7 @@ def benchmark_models(models, device="cuda", n_devices=1, cache_in_cpu=True): cache_in_cpu=cache_in_cpu, ) for option, result in results.items(): - rows.append( - {"model": model, "dtype": dtype, "options": option, **result} - ) + rows.append({"model": model, "dtype": dtype, "options": option, **result}) return pd.DataFrame(rows) @@ -441,9 +434,7 @@ def check_dtype(dtype, margin, no_processing=False): for model_path in ["gpt2", "roneneldan/TinyStories-33M", "EleutherAI/pythia-70m"]: if no_processing: # For low precision, the processing is not advised. 
- model = HookedTransformer.from_pretrained_no_processing( - model_path, torch_dtype=dtype - ) + model = HookedTransformer.from_pretrained_no_processing(model_path, torch_dtype=dtype) else: model = HookedTransformer.from_pretrained(model_path, torch_dtype=dtype) @@ -502,9 +493,7 @@ def remove_pos_embed(z, hook): z[:] = 0.0 return z - _ = model.run_with_hooks( - "Hello, world", fwd_hooks=[("hook_pos_embed", remove_pos_embed)] - ) + _ = model.run_with_hooks("Hello, world", fwd_hooks=[("hook_pos_embed", remove_pos_embed)]) # Check that pos embed has not been permanently changed assert (model.W_pos == initial_W_pos).all() diff --git a/tests/acceptance/test_multi_gpu.py b/tests/acceptance/test_multi_gpu.py index 260344fcf..f5f082c33 100644 --- a/tests/acceptance/test_multi_gpu.py +++ b/tests/acceptance/test_multi_gpu.py @@ -19,9 +19,7 @@ def gpt2_medium_on_4_devices(): return model -@pytest.mark.skipif( - torch.cuda.device_count() < 4, reason="Requires at least 4 CUDA devices" -) +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 CUDA devices") def test_get_device_for_block_index(gpt2_medium_on_4_devices): config = gpt2_medium_on_4_devices.cfg n_layers = config.n_layers @@ -44,19 +42,14 @@ def test_get_device_for_block_index(gpt2_medium_on_4_devices): device_override_obj = torch.device("cuda") for i in range(n_layers): expected_device = torch.device(device_override_obj.type, i // layers_per_device) - assert ( - get_device_for_block_index(i, config, device_override_obj) - == expected_device - ) + assert get_device_for_block_index(i, config, device_override_obj) == expected_device # Test when index is out of bounds # with pytest.raises(IndexError): # get_device_for_block_index(n_layers, config) -@pytest.mark.skipif( - torch.cuda.device_count() < 4, reason="Requires at least 4 CUDA devices" -) +@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 CUDA devices") @pytest.mark.parametrize("n_devices", [1, 2, 3, 4]) def 
test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices): model_1_device = gpt2_medium_on_1_device @@ -96,9 +89,7 @@ def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices): cache_device = gpt2_cache_n_devices[f"blocks.{i}.mlp.hook_post"].device assert cache_device == expected_device - assert torch.allclose( - gpt2_logits_1_device.to("cpu"), gpt2_logits_n_devices.to("cpu") - ) + assert torch.allclose(gpt2_logits_1_device.to("cpu"), gpt2_logits_n_devices.to("cpu")) for key in gpt2_cache_1_device.keys(): assert torch.allclose( gpt2_cache_1_device[key].to("cpu"), gpt2_cache_n_devices[key].to("cpu") @@ -123,9 +114,7 @@ def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices): ) -@pytest.mark.skipif( - torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA devices" -) +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 CUDA devices") def test_cache_device(): model = HookedTransformer.from_pretrained("gpt2-small", device="cuda:1") @@ -135,15 +124,11 @@ def test_cache_device(): ) logits, cache = model.run_with_cache("Hello there", device="cpu") - assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device( - torch.device("cpu") - ) + assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cpu")) model.to("cuda") logits, cache = model.run_with_cache("Hello there") - assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device( - logits.device - ) + assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(logits.device) def norm_device(device): diff --git a/tests/acceptance/test_tokenizer_special_tokens.py b/tests/acceptance/test_tokenizer_special_tokens.py index 5453b9085..6e4a93a98 100644 --- a/tests/acceptance/test_tokenizer_special_tokens.py +++ b/tests/acceptance/test_tokenizer_special_tokens.py @@ -27,9 +27,7 @@ def test_d_vocab_from_tokenizer(): else: tokenizer_name = loading.get_official_model_name(model_name) 
- model = HookedTransformer( - cfg=cfg, tokenizer=AutoTokenizer.from_pretrained(tokenizer_name) - ) + model = HookedTransformer(cfg=cfg, tokenizer=AutoTokenizer.from_pretrained(tokenizer_name)) tokens_with_bos = model.to_tokens(test_string) tokens_without_bos = model.to_tokens(test_string, prepend_bos=False) diff --git a/tests/manual_checks/manual_checks_type_annotations.py b/tests/manual_checks/manual_checks_type_annotations.py index e89c1dcc4..a532a05ac 100644 --- a/tests/manual_checks/manual_checks_type_annotations.py +++ b/tests/manual_checks/manual_checks_type_annotations.py @@ -9,9 +9,7 @@ prompt = "Hello World!" tokens = model.to_tokens(prompt, prepend_bos=False) logits_tokens = model(tokens) -logits_text: Float[torch.Tensor, "1 n_tokens d_vocab"] = model( - prompt, prepend_bos=False -) +logits_text: Float[torch.Tensor, "1 n_tokens d_vocab"] = model(prompt, prepend_bos=False) # n.b. that i used this file to see if my type annotations were working- they were! i occasionally # changed one of the sizes and saw that the type checker caught it. 
diff --git a/tests/unit/factored_matrix/test_multiply_by_matrix.py b/tests/unit/factored_matrix/test_multiply_by_matrix.py index 85caff1a7..91e2dca4e 100644 --- a/tests/unit/factored_matrix/test_multiply_by_matrix.py +++ b/tests/unit/factored_matrix/test_multiply_by_matrix.py @@ -45,9 +45,7 @@ def test_left_multiply_when_both_have_leading_dim(self, a, b, matrix): b_with_leading = repeat(b, "x y -> b x y", b=2) matrix_with_leading = repeat(matrix, "x y -> b x y", b=2) - product = self._test_multiply( - a_with_leading, b_with_leading, matrix_with_leading - ) + product = self._test_multiply(a_with_leading, b_with_leading, matrix_with_leading) assert product.A.shape[:-2] == (2,) assert product.B.shape[:-2] == (2,) diff --git a/tests/unit/factored_matrix/test_properties.py b/tests/unit/factored_matrix/test_properties.py index d724ea854..c18b760cf 100644 --- a/tests/unit/factored_matrix/test_properties.py +++ b/tests/unit/factored_matrix/test_properties.py @@ -66,9 +66,7 @@ def test_transpose_property(self, factored_matrices): def test_svd_property(self, factored_matrices): for factored_matrix in factored_matrices: U, S, Vh = factored_matrix.svd() - assert torch.allclose( - factored_matrix.AB, U @ torch.diag_embed(S) @ Vh.T, atol=1e-5 - ) + assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ Vh.T, atol=1e-5) # test that U and Vh are unitary assert torch.allclose(U.T @ U, torch.eye(U.shape[-1]), atol=1e-5) assert torch.allclose(Vh.T @ Vh, torch.eye(Vh.shape[-1]), atol=1e-5) @@ -76,9 +74,7 @@ def test_svd_property(self, factored_matrices): def test_svd_property_leading_ones(self, factored_matrices_leading_ones): for factored_matrix in factored_matrices_leading_ones: U, S, Vh = factored_matrix.svd() - assert torch.allclose( - factored_matrix.AB, U @ torch.diag_embed(S) @ Vh.mT, atol=1e-5 - ) + assert torch.allclose(factored_matrix.AB, U @ torch.diag_embed(S) @ Vh.mT, atol=1e-5) # test that U and Vh are unitary assert torch.allclose(U.mT @ U, 
torch.eye(U.shape[-1]), atol=1e-5) assert torch.allclose(Vh.mT @ Vh, torch.eye(Vh.shape[-1]), atol=1e-5) @@ -123,9 +119,7 @@ def test_pair_property(self, factored_matrices, random_matrices): def test_norm_property(self, factored_matrices): for factored_matrix in factored_matrices: - assert torch.allclose( - factored_matrix.norm(), factored_matrix.AB.norm(), atol=1e-5 - ) + assert torch.allclose(factored_matrix.norm(), factored_matrix.AB.norm(), atol=1e-5) def test_get_corner(self, factored_matrices): for factored_matrix in factored_matrices: @@ -143,9 +137,7 @@ def test_ndim(self, factored_matrices): def test_collapse_l(self, factored_matrices): for factored_matrix in factored_matrices: result = factored_matrix.collapse_l() - expected = factored_matrix.S[..., :, None] * utils.transpose( - factored_matrix.Vh - ) + expected = factored_matrix.S[..., :, None] * utils.transpose(factored_matrix.Vh) assert torch.allclose(result, expected) def test_collapse_r(self, factored_matrices): diff --git a/tests/unit/test_attention_mask.py b/tests/unit/test_attention_mask.py index df2c147ce..6b0951f5c 100644 --- a/tests/unit/test_attention_mask.py +++ b/tests/unit/test_attention_mask.py @@ -35,9 +35,7 @@ def attn_scores_hook(attn_scores, hook): return attn_scores def attn_hook(attn, hook): - assert torch.all( - attn[:, :, masked] == 0 - ), "Attention pattern attends outside the mask" + assert torch.all(attn[:, :, masked] == 0), "Attention pattern attends outside the mask" return attn diff --git a/tests/unit/test_cache_pos_slice.py b/tests/unit/test_cache_pos_slice.py index 41e6982b7..1b26c981d 100644 --- a/tests/unit/test_cache_pos_slice.py +++ b/tests/unit/test_cache_pos_slice.py @@ -21,16 +21,12 @@ def test_run_with_cache_pos_slice_keep_batch(): num_tokens = len(model.tokenizer.encode(prompt)) for i in range(-1, num_tokens + 1): - _, cache_with_slice = model.run_with_cache( - prompt, return_type=None, pos_slice=i - ) + _, cache_with_slice = model.run_with_cache(prompt, 
return_type=None, pos_slice=i) assert cache_with_slice["embed"].shape == torch.Size([1, 1, d_model]) assert cache_with_slice["q", 0].shape == torch.Size([1, 1, n_heads, d_head]) - assert torch.equal( - cache_no_slice["embed"][0, i, :], cache_with_slice["embed"][0, 0, :] - ) + assert torch.equal(cache_no_slice["embed"][0, i, :], cache_with_slice["embed"][0, 0, :]) assert torch.equal( cache_no_slice["pos_embed"][0, i, :], cache_with_slice["pos_embed"][0, 0, :] ) @@ -143,25 +139,17 @@ def test_run_with_cache_pos_slice_keep_batch(): def test_run_with_cache_pos_slice_remove_batch(): - _, cache_no_slice = model.run_with_cache( - prompt, remove_batch_dim=True, return_type=None - ) + _, cache_no_slice = model.run_with_cache(prompt, remove_batch_dim=True, return_type=None) num_tokens = len(model.tokenizer.encode(prompt)) for i in range(-1, num_tokens + 1): - _, cache_with_slice = model.run_with_cache( - prompt, remove_batch_dim=True, pos_slice=i - ) + _, cache_with_slice = model.run_with_cache(prompt, remove_batch_dim=True, pos_slice=i) assert cache_with_slice["embed"].shape == torch.Size([1, d_model]) assert cache_with_slice["q", 0].shape == torch.Size([1, n_heads, d_head]) - assert torch.equal( - cache_no_slice["embed"][i, :], cache_with_slice["embed"][0, :] - ) - assert torch.equal( - cache_no_slice["pos_embed"][i, :], cache_with_slice["pos_embed"][0, :] - ) + assert torch.equal(cache_no_slice["embed"][i, :], cache_with_slice["embed"][0, :]) + assert torch.equal(cache_no_slice["pos_embed"][i, :], cache_with_slice["pos_embed"][0, :]) for layer in range(n_layers): assert torch.equal( diff --git a/tests/unit/test_create_hooked_encoder.py b/tests/unit/test_create_hooked_encoder.py index a1adc7ef3..45e549aaa 100644 --- a/tests/unit/test_create_hooked_encoder.py +++ b/tests/unit/test_create_hooked_encoder.py @@ -6,9 +6,7 @@ @pytest.fixture def cfg(): - return HookedTransformerConfig( - d_head=4, d_model=12, n_ctx=5, n_layers=3, act_fn="gelu" - ) + return 
HookedTransformerConfig(d_head=4, d_model=12, n_ctx=5, n_layers=3, act_fn="gelu") def test_pass_tokenizer(cfg): diff --git a/tests/unit/test_grouped_query_attention.py b/tests/unit/test_grouped_query_attention.py index 885ec39a0..633bb9180 100644 --- a/tests/unit/test_grouped_query_attention.py +++ b/tests/unit/test_grouped_query_attention.py @@ -75,8 +75,6 @@ def test_grouped_query_attention_output_is_correct(): value_input = torch.rand((1, 5, d_model)) regular_attn_output = regular_attention(query_input, key_input, value_input) - grouped_query_attn_output = grouped_query_attention( - query_input, key_input, value_input - ) + grouped_query_attn_output = grouped_query_attention(query_input, key_input, value_input) assert torch.equal(regular_attn_output, grouped_query_attn_output) diff --git a/tests/unit/test_head_detector.py b/tests/unit/test_head_detector.py index a039602a2..e465f8805 100644 --- a/tests/unit/test_head_detector.py +++ b/tests/unit/test_head_detector.py @@ -284,9 +284,7 @@ def test_detect_head_exclude_bos(error_measure: ErrorMeasure, expected: torch.Te ("abs", expected_previous_exclude_current_token_match_abs), ), ) -def test_detect_head_exclude_current_token( - error_measure: ErrorMeasure, expected: torch.Tensor -): +def test_detect_head_exclude_current_token(error_measure: ErrorMeasure, expected: torch.Tensor): assert torch.allclose( detect_head( model, @@ -381,9 +379,7 @@ def test_detect_head_with_invalid_detection_pattern(): class Test_detect_head_non_lower_triangular_detection_pattern: - detection_pattern = torch.tril( - torch.ones(test_duplicated_seq_len, test_duplicated_seq_len) - ) + detection_pattern = torch.tril(torch.ones(test_duplicated_seq_len, test_duplicated_seq_len)) def test_no_error(self): detect_head( @@ -440,16 +436,14 @@ def test_allclose_abs(self): def test_isclose_mul(self): assert math.isclose( torch.sum(self.match_abs), - self.match_mul[0, 0].item() - - (model.cfg.n_layers * model.cfg.n_heads - 1), + self.match_mul[0, 
0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1), abs_tol=ATOL, ) def test_isclose_abs(self): assert math.isclose( torch.sum(self.match_abs), - self.match_abs[0, 0].item() - - (model.cfg.n_layers * model.cfg.n_heads - 1), + self.match_abs[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1), abs_tol=ATOL, ) @@ -486,16 +480,14 @@ def test_allclose_abs(self): def test_isclose_mul(self): assert math.isclose( torch.sum(self.match_mul), - self.match_mul[0, 0].item() - - (model.cfg.n_layers * model.cfg.n_heads - 1), + self.match_mul[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1), abs_tol=ATOL, ) def test_isclose_abs(self): assert math.isclose( torch.sum(self.match_abs), - self.match_abs[0, 0].item() - - (model.cfg.n_layers * model.cfg.n_heads - 1), + self.match_abs[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1), abs_tol=ATOL, ) @@ -532,16 +524,14 @@ def test_allclose_abs(self): def test_isclose_mul(self): assert math.isclose( torch.sum(self.match_mul), - self.match_mul[0, 0].item() - - (model.cfg.n_layers * model.cfg.n_heads - 1), + self.match_mul[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1), abs_tol=ATOL, ) def test_isclose_abs(self): assert math.isclose( torch.sum(self.match_abs), - self.match_abs[0, 0].item() - - (model.cfg.n_layers * model.cfg.n_heads - 1), + self.match_abs[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1), abs_tol=ATOL, ) @@ -578,16 +568,14 @@ def test_allclose_abs(self): def test_isclose_mul(self): assert math.isclose( torch.sum(self.match_mul), - self.match_mul[0, 0].item() - - (model.cfg.n_layers * model.cfg.n_heads - 1), + self.match_mul[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1), abs_tol=ATOL, ) def test_isclose_abs(self): assert math.isclose( torch.sum(self.match_abs), - self.match_abs[0, 0].item() - - (model.cfg.n_layers * model.cfg.n_heads - 1), + self.match_abs[0, 0].item() - (model.cfg.n_layers * model.cfg.n_heads - 1), abs_tol=ATOL, ) @@ -632,9 +620,7 @@ 
class Test_duplicate_token_head: def test1(self): assert ( - get_duplicate_token_head_detection_pattern( - model.to_tokens(test_regular_sequence).cpu() - ) + get_duplicate_token_head_detection_pattern(model.to_tokens(test_regular_sequence).cpu()) == torch.zeros(4, 4) ).all() @@ -655,9 +641,7 @@ class Test_induction_head_detection: def test1(self): assert ( - get_duplicate_token_head_detection_pattern( - model.to_tokens(test_regular_sequence).cpu() - ) + get_duplicate_token_head_detection_pattern(model.to_tokens(test_regular_sequence).cpu()) == torch.zeros(4, 4) ).all() diff --git a/tests/unit/test_hooks.py b/tests/unit/test_hooks.py index 1c41b8e24..231a57150 100644 --- a/tests/unit/test_hooks.py +++ b/tests/unit/test_hooks.py @@ -116,9 +116,7 @@ def test_remove_hook(): model.add_perma_hook(embed, c.inc) assert len(model.hook_dict["hook_embed"].fwd_hooks) == 1 # 1 after adding model.remove_all_hook_fns() - assert ( - len(model.hook_dict["hook_embed"].fwd_hooks) == 1 - ) # permanent not removed without flag + assert len(model.hook_dict["hook_embed"].fwd_hooks) == 1 # permanent not removed without flag model.remove_all_hook_fns(including_permanent=True) assert len(model.hook_dict["hook_embed"].fwd_hooks) == 0 # removed now model.run_with_hooks(prompt, fwd_hooks=[]) @@ -182,11 +180,7 @@ def identity_hook(z, hook): @pytest.mark.parametrize( "zero_attach_pos,prepend", - [ - (zero_attach_pos, prepend) - for zero_attach_pos in range(2) - for prepend in [True, False] - ], + [(zero_attach_pos, prepend) for zero_attach_pos in range(2) for prepend in [True, False]], ) def test_prepending_hooks(zero_attach_pos, prepend): """Add two hooks to a model: one that sets last layer activations to all 0s diff --git a/tests/unit/test_kv_cache.py b/tests/unit/test_kv_cache.py index b69b6b3b9..435d7fa42 100644 --- a/tests/unit/test_kv_cache.py +++ b/tests/unit/test_kv_cache.py @@ -69,9 +69,7 @@ def test_multiple_new_tokens(pretrained): past_kv_cache=past_kv_cache, ) assert 
t.allclose(no_cache_logits[:, -1], with_cache_logits[:, -1], atol=atol) - assert t.allclose( - no_cache_logits[:, -new_tokens_len:], with_cache_logits, atol=atol - ) + assert t.allclose(no_cache_logits[:, -new_tokens_len:], with_cache_logits, atol=atol) @pytest.mark.parametrize("pre_padding", ["left", "right", None]) @@ -95,17 +93,13 @@ def test_multi_token_batch(pretrained, pre_padding, post_padding): " by the candidate", ] - first_post_prompt_tokens = model.to_tokens( - padded_batch_post_prompts[0], prepend_bos=False - ) + first_post_prompt_tokens = model.to_tokens(padded_batch_post_prompts[0], prepend_bos=False) first_full_prompt_tokens = t.cat( [model.to_tokens(padded_batch_pre_prompts[0]), first_post_prompt_tokens], dim=-1 ) first_post_prompt_len = first_post_prompt_tokens.shape[-1] first_prompt_no_cache_logits = model(first_full_prompt_tokens) - first_post_prompt_no_cache_logits = first_prompt_no_cache_logits[ - 0, -first_post_prompt_len: - ] + first_post_prompt_no_cache_logits = first_prompt_no_cache_logits[0, -first_post_prompt_len:] if pre_padding is None: batch_pre_prompt_tokens = model.to_tokens(unpadded_batch_pre_prompts) @@ -116,9 +110,7 @@ def test_multi_token_batch(pretrained, pre_padding, post_padding): ) if post_padding is None: - batch_post_prompt_tokens = model.to_tokens( - unpadded_batch_post_prompts, prepend_bos=False - ) + batch_post_prompt_tokens = model.to_tokens(unpadded_batch_post_prompts, prepend_bos=False) else: assert post_padding == "left" or post_padding == "right" batch_post_prompt_tokens = model.to_tokens( @@ -130,9 +122,7 @@ def test_multi_token_batch(pretrained, pre_padding, post_padding): past_kv_cache = HookedTransformerKeyValueCache.init_cache( model.cfg, model.cfg.device, batch_pre_prompt_tokens.shape[0] ) - model( - batch_pre_prompt_tokens, past_kv_cache=past_kv_cache, padding_side=pre_padding - ) + model(batch_pre_prompt_tokens, past_kv_cache=past_kv_cache, padding_side=pre_padding) past_kv_cache.freeze() with_cache_logits = 
model( batch_post_prompt_tokens, @@ -141,14 +131,10 @@ def test_multi_token_batch(pretrained, pre_padding, post_padding): prepend_bos=False, ) if post_padding == "left" or post_padding is None: - first_post_prompt_with_cache_logits = with_cache_logits[ - 0, -first_post_prompt_len: - ] + first_post_prompt_with_cache_logits = with_cache_logits[0, -first_post_prompt_len:] else: assert post_padding == "right" - first_post_prompt_with_cache_logits = with_cache_logits[ - 0, :first_post_prompt_len - ] + first_post_prompt_with_cache_logits = with_cache_logits[0, :first_post_prompt_len] no_cache_probs = t.softmax(first_post_prompt_no_cache_logits, dim=-1) with_cache_probs = t.softmax(first_post_prompt_with_cache_logits, dim=-1) @@ -249,9 +235,7 @@ def test_kv_cache_and_start_at_layer(pretrained): _, toks, shortformer_pos_embed, attn_mask = model.input_to_embed( single_new_token, past_kv_cache=past_kv_cache ) - _, cache = model.run_with_cache( - single_new_token, stop_at_layer=4, past_kv_cache=past_kv_cache - ) + _, cache = model.run_with_cache(single_new_token, stop_at_layer=4, past_kv_cache=past_kv_cache) resid_3 = cache["blocks.3.hook_resid_pre"] with_cache_logits = model( resid_3, diff --git a/tests/unit/test_left_padding.py b/tests/unit/test_left_padding.py index b40f97fba..a4961dc74 100644 --- a/tests/unit/test_left_padding.py +++ b/tests/unit/test_left_padding.py @@ -89,9 +89,7 @@ def test_pos_embed(self, model, padding_side, prepend_bos): attended_output_pos_embed = output_pos_embed[attention_mask.bool()] - assert torch.allclose( - attended_output_pos_embed, target_output_pos_embed, atol=1e-4 - ) + assert torch.allclose(attended_output_pos_embed, target_output_pos_embed, atol=1e-4) # padded positions should have zero pos_embed assert output_pos_embed[~attention_mask.bool()].sum() == 0 @@ -117,9 +115,7 @@ def test_pos_embed_with_cache(self, model, padding_side, prepend_bos): model.tokenizer, tokens, prepend_bos ) # [batch pos] 
past_kv_cache.append_attention_mask(attention_mask) - attention_mask_2 = utils.get_attention_mask( - model.tokenizer, tokens_2, False - ) # [batch pos] + attention_mask_2 = utils.get_attention_mask(model.tokenizer, tokens_2, False) # [batch pos] cached_attention_mask = past_kv_cache.append_attention_mask(attention_mask_2) output_pos_embed = model.pos_embed( @@ -141,9 +137,7 @@ def test_pos_embed_with_cache(self, model, padding_side, prepend_bos): attended_output_pos_embed = output_pos_embed[attention_mask_2.bool()] - assert torch.allclose( - attended_output_pos_embed, target_output_pos_embed, atol=1e-4 - ) + assert torch.allclose(attended_output_pos_embed, target_output_pos_embed, atol=1e-4) # padded positions should have zero pos_embed assert output_pos_embed[~attention_mask_2.bool()].sum() == 0 diff --git a/tests/unit/test_only_tokenizer.py b/tests/unit/test_only_tokenizer.py index 867224d3d..fa2642b1b 100644 --- a/tests/unit/test_only_tokenizer.py +++ b/tests/unit/test_only_tokenizer.py @@ -30,17 +30,13 @@ def __init__( elif self.cfg.tokenizer_name is not None: # If we have a tokenizer name, we can load it from HuggingFace self.set_tokenizer( - AutoTokenizer.from_pretrained( - self.cfg.tokenizer_name, add_bos_token=True - ), + AutoTokenizer.from_pretrained(self.cfg.tokenizer_name, add_bos_token=True), default_padding_side=default_padding_side, ) else: # If no tokenizer name is provided, we assume we're training on an algorithmic task and will pass in tokens # directly. In this case, we don't need a tokenizer. 
- assert ( - self.cfg.d_vocab != -1 - ), "Must provide a tokenizer if d_vocab is not provided" + assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided" self.tokenizer = None if default_padding_side != "right": logging.warning( @@ -101,9 +97,7 @@ class TestTokenizer: # helper functions def get_num_tokens_in_prompt(self, model, prompt, intended_prepend_bos): - tokenizer = AutoTokenizer.from_pretrained( - model.tokenizer.name_or_path, add_bos_token=False - ) + tokenizer = AutoTokenizer.from_pretrained(model.tokenizer.name_or_path, add_bos_token=False) tokens = tokenizer( prompt, )["input_ids"] @@ -126,9 +120,7 @@ def check_tokens_length(self, model, str_tokens, tokens, intended_prepend_bos): assert len(str_tokens) == tokens.shape[1] == expected_num_tokens def check_prompt(self, model, intended_prepend_bos, overriding_prepend_bos=None): - str_tokens = model.to_str_tokens( - self.prompt, prepend_bos=overriding_prepend_bos - ) + str_tokens = model.to_str_tokens(self.prompt, prepend_bos=overriding_prepend_bos) tokens = model.to_tokens(self.prompt, prepend_bos=overriding_prepend_bos) self.check_first_token(model, str_tokens, tokens, intended_prepend_bos) @@ -164,9 +156,7 @@ def check_prompts( if model.tokenizer.pad_token_id != model.tokenizer.bos_token_id: if intended_prepend_bos: - assert (tokens == model.tokenizer.bos_token_id).sum() == tokens.shape[ - 0 - ], tokens + assert (tokens == model.tokenizer.bos_token_id).sum() == tokens.shape[0], tokens else: assert (tokens == model.tokenizer.bos_token_id).sum() == 0, tokens @@ -220,9 +210,7 @@ def test_given_defaults(self, model_name): @pytest.mark.parametrize("intended_prepend_bos", [True, False]) @pytest.mark.parametrize("intended_padding_side", ["left", "right"]) - def test_changing_defaults( - self, model, intended_prepend_bos, intended_padding_side - ): + def test_changing_defaults(self, model, intended_prepend_bos, intended_padding_side): model.tokenizer.padding_side = intended_padding_side 
model.cfg.default_prepend_bos = intended_prepend_bos @@ -231,9 +219,7 @@ def test_changing_defaults( @pytest.mark.parametrize("intended_prepend_bos", [True, False]) @pytest.mark.parametrize("intended_padding_side", ["left", "right"]) - def test_overriding_defaults( - self, model, intended_prepend_bos, intended_padding_side - ): + def test_overriding_defaults(self, model, intended_prepend_bos, intended_padding_side): self.check_prompt(model, intended_prepend_bos, intended_prepend_bos) self.check_prompts( model, diff --git a/tests/unit/test_prepend_bos.py b/tests/unit/test_prepend_bos.py index 949939936..afb85d933 100644 --- a/tests/unit/test_prepend_bos.py +++ b/tests/unit/test_prepend_bos.py @@ -9,9 +9,7 @@ class TestPrependBos: # helper functions def get_num_tokens_in_prompt(self, model, prompt, intended_prepend_bos): - tokenizer = AutoTokenizer.from_pretrained( - model.tokenizer.name_or_path, add_bos_token=False - ) + tokenizer = AutoTokenizer.from_pretrained(model.tokenizer.name_or_path, add_bos_token=False) tokens = tokenizer( prompt, )["input_ids"] @@ -26,15 +24,11 @@ def check_first_token(self, model, str_tokens, tokens, intended_prepend_bos): assert str_tokens[0] != model.tokenizer.bos_token assert tokens[0][0] != model.tokenizer.bos_token_id - def check_tokens_length( - self, model, logits, str_tokens, tokens, intended_prepend_bos - ): + def check_tokens_length(self, model, logits, str_tokens, tokens, intended_prepend_bos): expected_num_tokens = self.get_num_tokens_in_prompt( model, self.prompt, intended_prepend_bos ) - assert ( - logits.shape[1] == len(str_tokens) == tokens.shape[1] == expected_num_tokens - ) + assert logits.shape[1] == len(str_tokens) == tokens.shape[1] == expected_num_tokens # fixtures @pytest.fixture(scope="class", params=["gpt2", "facebook/opt-125m"]) @@ -59,13 +53,9 @@ def test_default_prepend_bos(self, model_name): tokens = model.to_tokens(self.prompt) # [batch pos] self.check_first_token(model, str_tokens, tokens, 
intended_prepend_bos) - self.check_tokens_length( - model, logits, str_tokens, tokens, intended_prepend_bos - ) + self.check_tokens_length(model, logits, str_tokens, tokens, intended_prepend_bos) - bos_position = model.get_token_position( - model.tokenizer.bos_token_id, self.prompt - ) + bos_position = model.get_token_position(model.tokenizer.bos_token_id, self.prompt) assert bos_position == 0 def test_default_prepend_bos_to_false(self, model_name): @@ -80,34 +70,24 @@ def test_default_prepend_bos_to_false(self, model_name): tokens = model.to_tokens(self.prompt) self.check_first_token(model, str_tokens, tokens, intended_prepend_bos) - self.check_tokens_length( - model, logits, str_tokens, tokens, intended_prepend_bos - ) + self.check_tokens_length(model, logits, str_tokens, tokens, intended_prepend_bos) @pytest.mark.parametrize("intended_prepend_bos", [True, False]) def test_override_prepend_bos(self, model, intended_prepend_bos): for default_prepend_bos in [True, False]: model.cfg.default_prepend_bos = default_prepend_bos - logits = model( - self.prompt, prepend_bos=intended_prepend_bos - ) # [batch pos d_vocab] - str_tokens = model.to_str_tokens( - self.prompt, prepend_bos=intended_prepend_bos - ) + logits = model(self.prompt, prepend_bos=intended_prepend_bos) # [batch pos d_vocab] + str_tokens = model.to_str_tokens(self.prompt, prepend_bos=intended_prepend_bos) tokens = model.to_tokens(self.prompt, prepend_bos=intended_prepend_bos) self.check_first_token(model, str_tokens, tokens, intended_prepend_bos) - self.check_tokens_length( - model, logits, str_tokens, tokens, intended_prepend_bos - ) + self.check_tokens_length(model, logits, str_tokens, tokens, intended_prepend_bos) def test_prepend_bos_with_get_token_position(self, model_name): model = HookedTransformer.from_pretrained(model_name) - bos_position = model.get_token_position( - model.tokenizer.bos_token_id, self.prompt - ) + bos_position = model.get_token_position(model.tokenizer.bos_token_id, self.prompt) 
assert bos_position == 0 with pytest.raises(AssertionError): @@ -117,9 +97,7 @@ def test_prepend_bos_with_get_token_position(self, model_name): model.cfg.default_prepend_bos = False with pytest.raises(AssertionError): - bos_position = model.get_token_position( - model.tokenizer.bos_token_id, self.prompt - ) + bos_position = model.get_token_position(model.tokenizer.bos_token_id, self.prompt) bos_position = model.get_token_position( model.tokenizer.bos_token_id, self.prompt, prepend_bos=True diff --git a/tests/unit/test_start_at_layer.py b/tests/unit/test_start_at_layer.py index c87779bca..f1d007829 100644 --- a/tests/unit/test_start_at_layer.py +++ b/tests/unit/test_start_at_layer.py @@ -124,9 +124,7 @@ def test_no_start_logit_output(setup_data: Dict[str, Any]): def test_no_start_none_output(setup_data: Dict[str, Any]): model, rand_input = setup_data["model"], setup_data["rand_input"] - output, cache = model.run_with_cache( - rand_input, start_at_layer=None, return_type=None - ) + output, cache = model.run_with_cache(rand_input, start_at_layer=None, return_type=None) assert output is None assert "hook_embed" in cache.keys() @@ -183,11 +181,7 @@ def test_start_at_layer_kwargs(): shortformer_pos_embed, attention_mask, ) = model.input_to_embed(input) - assert ( - tokens is not None - and shortformer_pos_embed is not None - and attention_mask is not None - ) + assert tokens is not None and shortformer_pos_embed is not None and attention_mask is not None start_at_layer_output = model( rand_embed, diff --git a/tests/unit/test_stop_at_layer.py b/tests/unit/test_stop_at_layer.py index 3bbae6cd4..2692c8f49 100644 --- a/tests/unit/test_stop_at_layer.py +++ b/tests/unit/test_stop_at_layer.py @@ -220,9 +220,7 @@ def test_no_stop_no_output(): ) rand_input = torch.randint(0, 20, (2, 10)) - output, cache = model.run_with_cache( - rand_input, stop_at_layer=None, return_type=None - ) + output, cache = model.run_with_cache(rand_input, stop_at_layer=None, return_type=None) assert 
output is None assert "hook_embed" in cache.keys() diff --git a/tests/unit/test_svd_interpreter.py b/tests/unit/test_svd_interpreter.py index d23d30ac5..face0643b 100644 --- a/tests/unit/test_svd_interpreter.py +++ b/tests/unit/test_svd_interpreter.py @@ -114,9 +114,7 @@ def test_svd_interpreter_returns_different_answers_for_different_models(): def test_svd_interpreter_fails_on_invalid_vector_type(): svd_interpreter = SVDInterpreter(model) with pytest.raises(BeartypeCallHintParamViolation) as e: - svd_interpreter.get_singular_vectors( - "test", layer_index=0, num_vectors=4, head_index=0 - ) + svd_interpreter.get_singular_vectors("test", layer_index=0, num_vectors=4, head_index=0) def test_svd_interpreter_fails_on_not_passing_required_head_index(): @@ -130,9 +128,7 @@ def test_svd_interpreter_fails_on_invalid_layer_index(): svd_interpreter = SVDInterpreter(model) for vector in VECTOR_TYPES: with pytest.raises(AssertionError) as e: - svd_interpreter.get_singular_vectors( - vector, layer_index=2, num_vectors=4, head_index=0 - ) + svd_interpreter.get_singular_vectors(vector, layer_index=2, num_vectors=4, head_index=0) assert str(e.value) == "Layer index must be between 0 and 1 but got 2" @@ -140,7 +136,5 @@ def test_svd_interpreter_fails_on_invalid_head_index(): # Only OV uses head index. svd_interpreter = SVDInterpreter(model) with pytest.raises(AssertionError) as e: - svd_interpreter.get_singular_vectors( - "OV", layer_index=0, num_vectors=4, head_index=8 - ) + svd_interpreter.get_singular_vectors("OV", layer_index=0, num_vectors=4, head_index=8) assert str(e.value) == "Head index must be between 0 and 7 but got 8" diff --git a/tests/unit/test_tokenization_methods.py b/tests/unit/test_tokenization_methods.py index d795d453a..acba3ebd7 100644 --- a/tests/unit/test_tokenization_methods.py +++ b/tests/unit/test_tokenization_methods.py @@ -58,9 +58,7 @@ def test_to_tokens_device(): s = "Hello, world!" 
tokens1 = model.to_tokens(s, move_to_device=False) tokens2 = model.to_tokens(s, move_to_device=True) - assert equal( - tokens1, tokens2 - ), "move to device has no effect when running tests on CPU" + assert equal(tokens1, tokens2), "move to device has no effect when running tests on CPU" def test_to_tokens_truncate(): @@ -125,9 +123,7 @@ def test_get_token_position_not_found(): input = "There were some biomolecules" with pytest.raises(AssertionError) as exc_info: model.get_token_position(single, input) - assert ( - str(exc_info.value) == f"The token does not occur in the prompt" - ), "assertion error" + assert str(exc_info.value) == f"The token does not occur in the prompt", "assertion error" def test_get_token_position_str(): diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index f08098c26..7feec34a0 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -274,21 +274,13 @@ def test_test_prompt( def test_override_or_use_default_value(): # Case when override is not None assert utils.override_or_use_default_value(default_flag=True, override=True) == True - assert ( - utils.override_or_use_default_value(default_flag=True, override=False) == False - ) - assert ( - utils.override_or_use_default_value(default_flag=False, override=True) == True - ) - assert ( - utils.override_or_use_default_value(default_flag=False, override=False) == False - ) + assert utils.override_or_use_default_value(default_flag=True, override=False) == False + assert utils.override_or_use_default_value(default_flag=False, override=True) == True + assert utils.override_or_use_default_value(default_flag=False, override=False) == False # Case when override is None assert utils.override_or_use_default_value(default_flag=True, override=None) == True - assert ( - utils.override_or_use_default_value(default_flag=False, override=None) == False - ) + assert utils.override_or_use_default_value(default_flag=False, override=None) == False # Case when override is not passed 
assert utils.override_or_use_default_value(default_flag=True) == True @@ -322,9 +314,7 @@ def model(self, model_name): @pytest.mark.parametrize("padding_side", ["left", "right"]) @pytest.mark.parametrize("prepend_bos", [True, False]) @pytest.mark.parametrize("prompts_with_sep", [True, False]) - def test_get_attention_mask( - self, model, padding_side, prepend_bos, prompts_with_sep - ): + def test_get_attention_mask(self, model, padding_side, prepend_bos, prompts_with_sep): # setup model.tokenizer.padding_side = padding_side model.tokenizer.sep_token_id = model.tokenizer.pad_token_id @@ -397,9 +387,7 @@ def test_calc_fan_in_fan_out(): assert fan_out == 3 # Test for the case when the tensor is 3D - tensor_3d = nn.Parameter( - torch.rand(2, 25, 5) - ) # 2 x 25 x 5, I'm not writing this out + tensor_3d = nn.Parameter(torch.rand(2, 25, 5)) # 2 x 25 x 5, I'm not writing this out fan_in, fan_out = utils.calc_fan_in_and_fan_out(tensor_3d) assert fan_in == 25 assert fan_out == 10 @@ -546,9 +534,7 @@ def test_init_xavier_normal(self, d_model, d_mlp): std = np.sqrt(2 / (d_mlp + d_model)) assert np.isclose(y.std().detach().numpy(), std, rtol=1e-2) - z = nn.Parameter( - torch.empty(d_model * 123) - ) # need to make this larger so std is accurate + z = nn.Parameter(torch.empty(d_model * 123)) # need to make this larger so std is accurate utils.init_xavier_normal_(z) std = np.sqrt(2 / (1 + d_model * 123)) assert np.isclose(z.std().detach().numpy(), std, rtol=1e-2) diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index 31906fade..caff121c3 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -114,9 +114,7 @@ class ActivationCache: Whether the activations have a batch dimension. 
""" - def __init__( - self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True - ): + def __init__(self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True): self.cache_dict = cache_dict self.model = model self.has_batch_dim = has_batch_dim @@ -138,9 +136,7 @@ def remove_batch_dim(self) -> ActivationCache: self.cache_dict[key] = self.cache_dict[key][0] self.has_batch_dim = False else: - logging.warning( - "Tried removing batch dimension after already having removed it." - ) + logging.warning("Tried removing batch dimension after already having removed it.") return self def __repr__(self) -> str: @@ -207,9 +203,7 @@ def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCa DeprecationWarning, ) - self.cache_dict = { - key: value.to(device) for key, value in self.cache_dict.items() - } + self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()} if move_model: self.model.to(device) @@ -301,9 +295,7 @@ def __iter__(self) -> Iterator[str]: """ return self.cache_dict.__iter__() - def apply_slice_to_batch_dim( - self, batch_slice: Union[Slice, SliceInput] - ) -> ActivationCache: + def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache: """Apply a Slice to the Batch Dimension. 
Args: @@ -321,12 +313,9 @@ def apply_slice_to_batch_dim( ), "Cannot index into a cache without a batch dim" still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim new_cache_dict = { - name: batch_slice.apply(param, dim=0) - for name, param in self.cache_dict.items() + name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items() } - return ActivationCache( - new_cache_dict, self.model, has_batch_dim=still_has_batch_dim - ) + return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim) def accumulated_resid( self, @@ -338,9 +327,7 @@ def accumulated_resid( return_labels: bool = False, ) -> Union[ Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], - Tuple[ - Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str] - ], + Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], ]: """Accumulated Residual Stream. @@ -443,9 +430,7 @@ def accumulated_resid( components_list = [] for l in range(layer + 1): if l == self.model.cfg.n_layers: - components_list.append( - self[("resid_post", self.model.cfg.n_layers - 1)] - ) + components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)]) labels.append("final_post") continue components_list.append(self[("resid_pre", l)]) @@ -466,9 +451,7 @@ def accumulated_resid( def logit_attrs( self, - residual_stack: Float[ - torch.Tensor, "num_components *batch_and_pos_dims d_model" - ], + residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], tokens: Union[ str, int, @@ -549,9 +532,7 @@ def logit_attrs( if incorrect_tokens is not None: if isinstance(incorrect_tokens, str): - incorrect_tokens = torch.as_tensor( - self.model.to_single_token(incorrect_tokens) - ) + incorrect_tokens = torch.as_tensor(self.model.to_single_token(incorrect_tokens)) elif isinstance(incorrect_tokens, int): incorrect_tokens = torch.as_tensor(incorrect_tokens) @@ -564,9 +545,8 @@ def logit_attrs( ) # If 
incorrect_tokens was provided, take the logit difference - logit_directions = ( - logit_directions - - self.model.tokens_to_residual_directions(incorrect_tokens) + logit_directions = logit_directions - self.model.tokens_to_residual_directions( + incorrect_tokens ) scaled_residual_stack = self.apply_ln_to_stack( @@ -594,9 +574,7 @@ def decompose_resid( return_labels: bool = False, ) -> Union[ Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], - Tuple[ - Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str] - ], + Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], ]: """Decompose the Residual Stream. @@ -686,9 +664,7 @@ def compute_head_results( be useful if you forget. """ if "blocks.0.attn.hook_result" in self.cache_dict: - logging.warning( - "Tried to compute head results when they were already cached" - ) + logging.warning("Tried to compute head results when they were already cached") return for l in range(self.model.cfg.n_layers): # Note that we haven't enabled set item on this object so we need to edit the underlying @@ -861,9 +837,7 @@ def stack_neuron_results( apply_ln: bool = False, ) -> Union[ Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], - Tuple[ - Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str] - ], + Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], ]: """Stack Neuron Results @@ -911,9 +885,7 @@ def stack_neuron_results( for l in range(layer): # Note that this has shape batch x pos x head_index x d_model components.append( - self.get_neuron_results( - l, pos_slice=pos_slice, neuron_slice=neuron_slice - ) + self.get_neuron_results(l, pos_slice=pos_slice, neuron_slice=neuron_slice) ) labels.extend([f"L{l}N{h}" for h in neuron_labels]) if components: @@ -947,9 +919,7 @@ def stack_neuron_results( def apply_ln_to_stack( self, - residual_stack: Float[ - torch.Tensor, "num_components *batch_and_pos_dims 
d_model" - ], + residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], layer: Optional[int] = None, mlp_input: bool = False, pos_slice: Union[Slice, SliceInput] = None, @@ -1100,9 +1070,7 @@ def get_full_resid_decomposition( labels.append("pos_embed") components.append(pos_slice.apply(self["pos_embed"], -2)[None]) # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs. - bias = self.model.accumulated_bias( - layer, mlp_input, include_mlp_biases=expand_neurons - ) + bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons) bias = bias.expand((1,) + head_stack.shape[1:]) labels.append("bias") components.append(bias) diff --git a/transformer_lens/FactoredMatrix.py b/transformer_lens/FactoredMatrix.py index e037d45dc..1e1c813a6 100644 --- a/transformer_lens/FactoredMatrix.py +++ b/transformer_lens/FactoredMatrix.py @@ -218,9 +218,7 @@ def __getitem__(self, idx: Union[int, Tuple]) -> FactoredMatrix: elif length == len(self.shape): idx = self._convert_to_slice(idx, -1) idx = self._convert_to_slice(idx, -2) - return FactoredMatrix( - self.A[idx[:-1]], self.B[idx[:-2] + (slice(None), idx[-1])] - ) + return FactoredMatrix(self.A[idx[:-1]], self.B[idx[:-2] + (slice(None), idx[-1])]) else: raise ValueError( f"{idx} is too long an index for a FactoredMatrix with shape {self.shape}" diff --git a/transformer_lens/HookedEncoder.py b/transformer_lens/HookedEncoder.py index 51e357788..cc0d4e880 100644 --- a/transformer_lens/HookedEncoder.py +++ b/transformer_lens/HookedEncoder.py @@ -48,9 +48,7 @@ def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs): ) self.cfg = cfg - assert ( - self.cfg.n_devices == 1 - ), "Multiple devices not supported for HookedEncoder" + assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder" if tokenizer is not None: self.tokenizer = tokenizer elif self.cfg.tokenizer_name is not None: @@ -60,17 +58,13 @@ def __init__(self, 
cfg, tokenizer=None, move_to_device=True, **kwargs): if self.cfg.d_vocab == -1: # If we have a tokenizer, vocab size can be inferred from it. - assert ( - self.tokenizer is not None - ), "Must provide a tokenizer if d_vocab is not provided" + assert self.tokenizer is not None, "Must provide a tokenizer if d_vocab is not provided" self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 if self.cfg.d_vocab_out == -1: self.cfg.d_vocab_out = self.cfg.d_vocab self.embed = BertEmbed(self.cfg) - self.blocks = nn.ModuleList( - [BertBlock(self.cfg) for _ in range(self.cfg.n_layers)] - ) + self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]) self.mlm_head = BertMLMHead(cfg) self.unembed = Unembed(self.cfg) @@ -133,9 +127,7 @@ def forward( else None ) additive_attention_mask = ( - torch.where(mask == 1, large_negative_number, 0) - if mask is not None - else None + torch.where(mask == 1, large_negative_number, 0) if mask is not None else None ) for block in self.blocks: @@ -177,9 +169,7 @@ def run_with_cache( *model_args, remove_batch_dim=remove_batch_dim, **kwargs ) if return_cache_object: - cache = ActivationCache( - cache_dict, self, has_batch_dim=not remove_batch_dim - ) + cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) return out, cache else: return out, cache_dict @@ -302,86 +292,62 @@ def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: @property def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: """Stacks the key weights across all layers""" - return torch.stack( - [cast(BertBlock, block).attn.W_K for block in self.blocks], dim=0 - ) + return torch.stack([cast(BertBlock, block).attn.W_K for block in self.blocks], dim=0) @property def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: """Stacks the query weights across all layers""" - return torch.stack( - [cast(BertBlock, block).attn.W_Q for block in self.blocks], dim=0 - ) + return 
torch.stack([cast(BertBlock, block).attn.W_Q for block in self.blocks], dim=0) @property def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: """Stacks the value weights across all layers""" - return torch.stack( - [cast(BertBlock, block).attn.W_V for block in self.blocks], dim=0 - ) + return torch.stack([cast(BertBlock, block).attn.W_V for block in self.blocks], dim=0) @property def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: """Stacks the attn output weights across all layers""" - return torch.stack( - [cast(BertBlock, block).attn.W_O for block in self.blocks], dim=0 - ) + return torch.stack([cast(BertBlock, block).attn.W_O for block in self.blocks], dim=0) @property def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: """Stacks the MLP input weights across all layers""" - return torch.stack( - [cast(BertBlock, block).mlp.W_in for block in self.blocks], dim=0 - ) + return torch.stack([cast(BertBlock, block).mlp.W_in for block in self.blocks], dim=0) @property def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: """Stacks the MLP output weights across all layers""" - return torch.stack( - [cast(BertBlock, block).mlp.W_out for block in self.blocks], dim=0 - ) + return torch.stack([cast(BertBlock, block).mlp.W_out for block in self.blocks], dim=0) @property def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: """Stacks the key biases across all layers""" - return torch.stack( - [cast(BertBlock, block).attn.b_K for block in self.blocks], dim=0 - ) + return torch.stack([cast(BertBlock, block).attn.b_K for block in self.blocks], dim=0) @property def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: """Stacks the query biases across all layers""" - return torch.stack( - [cast(BertBlock, block).attn.b_Q for block in self.blocks], dim=0 - ) + return torch.stack([cast(BertBlock, block).attn.b_Q for block in self.blocks], dim=0) @property def b_V(self) -> Float[torch.Tensor, 
"n_layers n_heads d_head"]: """Stacks the value biases across all layers""" - return torch.stack( - [cast(BertBlock, block).attn.b_V for block in self.blocks], dim=0 - ) + return torch.stack([cast(BertBlock, block).attn.b_V for block in self.blocks], dim=0) @property def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: """Stacks the attn output biases across all layers""" - return torch.stack( - [cast(BertBlock, block).attn.b_O for block in self.blocks], dim=0 - ) + return torch.stack([cast(BertBlock, block).attn.b_O for block in self.blocks], dim=0) @property def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: """Stacks the MLP input biases across all layers""" - return torch.stack( - [cast(BertBlock, block).mlp.b_in for block in self.blocks], dim=0 - ) + return torch.stack([cast(BertBlock, block).mlp.b_in for block in self.blocks], dim=0) @property def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: """Stacks the MLP output biases across all layers""" - return torch.stack( - [cast(BertBlock, block).mlp.b_out for block in self.blocks], dim=0 - ) + return torch.stack([cast(BertBlock, block).mlp.b_out for block in self.blocks], dim=0) @property def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] @@ -396,8 +362,4 @@ def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] def all_head_labels(self) -> List[str]: """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index.""" - return [ - f"L{l}H{h}" - for l in range(self.cfg.n_layers) - for h in range(self.cfg.n_heads) - ] + return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index d0655c835..dc65c5c10 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -152,9 +152,7 @@ def __init__( else: # If no tokenizer name is provided, we assume we're training 
on an algorithmic task and # will pass in tokens directly. In this case, we don't need a tokenizer. - assert ( - self.cfg.d_vocab != -1 - ), "Must provide a tokenizer if d_vocab is not provided" + assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided" self.tokenizer = None if default_padding_side != "right": logging.warning( @@ -172,10 +170,7 @@ def __init__( self.hook_tokens = HookPoint() # [batch, pos] self.blocks = nn.ModuleList( - [ - TransformerBlock(self.cfg, block_index) - for block_index in range(self.cfg.n_layers) - ] + [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)] ) if self.cfg.normalization_type == "RMS": @@ -197,9 +192,7 @@ def __init__( # If it's None, don't create either layer pass else: - logging.warning( - "Invalid normalization_type passed in %s", self.cfg.normalization_type - ) + logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type) self.unembed = Unembed(self.cfg) if self.cfg.init_weights: @@ -250,9 +243,7 @@ def input_to_embed( self, input: Union[str, List[str], Int[torch.Tensor, "batch pos"]], prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[ - Union[Literal["left", "right"], None] - ] = USE_DEFAULT_VALUE, + padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, ) -> Tuple[ Float[torch.Tensor, "batch pos d_model"], # residual @@ -280,9 +271,7 @@ def input_to_embed( self.tokenizer is not None ), "Must provide a tokenizer if passing a string to the model" # This is only intended to support passing in a single string - tokens = self.to_tokens( - input, prepend_bos=prepend_bos, padding_side=padding_side - ) + tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) else: tokens = input if len(tokens.shape) == 1: @@ -291,18 +280,14 @@ def input_to_embed( if tokens.device.type != 
self.cfg.device: tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg)) - if ( - self.tokenizer and self.tokenizer.padding_side == "left" - ) or past_kv_cache is not None: - # If the padding side is left or we are using caching, we need to compute the attention mask - # for the adjustment of absolute positional embeddings and attention masking so that pad - # tokens are not attended. + if (self.tokenizer and self.tokenizer.padding_side == "left") or past_kv_cache is not None: + # If the padding side is left or we are using caching, we need to compute the attention + # mask for the adjustment of absolute positional embeddings and attention masking so + # that pad tokens are not attended. if prepend_bos is USE_DEFAULT_VALUE: prepend_bos = self.cfg.default_prepend_bos - attention_mask = utils.get_attention_mask( - self.tokenizer, tokens, prepend_bos - ) + attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos) if past_kv_cache is not None: # past_kv_cache is not None, so we're doing caching. 
@@ -375,14 +360,10 @@ def forward( return_type: Literal["logits"], loss_per_token: bool = False, prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[ - Union[Literal["left", "right"], None] - ] = USE_DEFAULT_VALUE, + padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, start_at_layer: Optional[int] = None, tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, - shortformer_pos_embed: Optional[ - Float[torch.Tensor, "batch pos d_model"] - ] = None, + shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, attention_mask: Optional[torch.Tensor] = None, # [batch pos] stop_at_layer: Optional[int] = None, past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, @@ -396,14 +377,10 @@ def forward( return_type: Literal["loss"], loss_per_token: bool = False, prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[ - Union[Literal["left", "right"], None] - ] = USE_DEFAULT_VALUE, + padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, start_at_layer: Optional[int] = None, tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, - shortformer_pos_embed: Optional[ - Float[torch.Tensor, "batch pos d_model"] - ] = None, + shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, attention_mask: Optional[torch.Tensor] = None, # [batch pos] stop_at_layer: Optional[int] = None, past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, @@ -417,14 +394,10 @@ def forward( return_type: Literal["both"], loss_per_token: bool = False, prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[ - Union[Literal["left", "right"], None] - ] = USE_DEFAULT_VALUE, + padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, start_at_layer: Optional[int] = None, tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, - shortformer_pos_embed: Optional[ - 
Float[torch.Tensor, "batch pos d_model"] - ] = None, + shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, attention_mask: Optional[torch.Tensor] = None, # [batch pos] stop_at_layer: Optional[int] = None, past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, @@ -438,14 +411,10 @@ def forward( return_type: Literal[None], loss_per_token: bool = False, prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[ - Union[Literal["left", "right"], None] - ] = USE_DEFAULT_VALUE, + padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, start_at_layer: Optional[int] = None, tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, - shortformer_pos_embed: Optional[ - Float[torch.Tensor, "batch pos d_model"] - ] = None, + shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, attention_mask: Optional[torch.Tensor] = None, # [batch pos] stop_at_layer: Optional[int] = None, past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, @@ -466,9 +435,7 @@ def forward( padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, start_at_layer: Optional[int] = None, tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, - shortformer_pos_embed: Optional[ - Float[torch.Tensor, "batch pos d_model"] - ] = None, + shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, attention_mask: Optional[torch.Tensor] = None, # [batch pos] stop_at_layer: Optional[int] = None, past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, @@ -580,9 +547,7 @@ def forward( residual, # Cache contains a list of HookedTransformerKeyValueCache objects, one for each # block - past_kv_cache_entry=( - past_kv_cache[i] if past_kv_cache is not None else None - ), + past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None, shortformer_pos_embed=shortformer_pos_embed, attention_mask=attention_mask, ) # [batch, pos, d_model] @@ 
-659,9 +624,7 @@ def run_with_cache( *model_args, remove_batch_dim=remove_batch_dim, **kwargs ) if return_cache_object: - cache = ActivationCache( - cache_dict, self, has_batch_dim=not remove_batch_dim - ) + cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) return out, cache else: return out, cache_dict @@ -718,9 +681,7 @@ def to_tokens( self, input: Union[str, List[str]], prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[ - Union[Literal["left", "right"], None] - ] = USE_DEFAULT_VALUE, + padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, move_to_device: bool = True, truncate: bool = True, ) -> Int[torch.Tensor, "batch pos"]: @@ -756,18 +717,14 @@ def to_tokens( with utils.LocallyOverridenDefaults( self, prepend_bos=prepend_bos, padding_side=padding_side ): - assert ( - self.tokenizer is not None - ), "Cannot use to_tokens without a tokenizer" + assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" assert ( self.cfg.tokenizer_prepends_bos is not None ), "Set the tokenizer for the model by calling set_tokenizer" if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos: # We want to prepend bos but the tokenizer doesn't automatically do it, so we add it manually - input = utils.get_input_with_manually_prepended_bos( - self.tokenizer, input - ) + input = utils.get_input_with_manually_prepended_bos(self.tokenizer, input) tokens = self.tokenizer( input, @@ -812,9 +769,7 @@ def to_string( # it's set, then tokenization is no longer invertible, and some tokens # with a bunch of whitespace get collapsed together if len(tokens.shape) == 2: - return self.tokenizer.batch_decode( - tokens, clean_up_tokenization_spaces=False - ) + return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) elif len(tokens.shape) <= 1: return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False) else: @@ -831,9 +786,7 @@ def 
to_str_tokens( list, ], prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[ - Union[Literal["left", "right"], None] - ] = USE_DEFAULT_VALUE, + padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, ) -> Union[List[str], List[List[str]]]: """Map text, a list of text or tokens to a list of tokens as strings. @@ -874,16 +827,14 @@ def to_str_tokens( if isinstance(input, list): return list( map( - lambda tokens: self.to_str_tokens( - tokens, prepend_bos, padding_side - ), + lambda tokens: self.to_str_tokens(tokens, prepend_bos, padding_side), input, ) ) # type: ignore elif isinstance(input, str): - tokens = self.to_tokens( - input, prepend_bos=prepend_bos, padding_side=padding_side - )[0] + tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[ + 0 + ] # Gemma tokenizer expects a batch dimension if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1: tokens = tokens.unsqueeze(1) @@ -907,9 +858,7 @@ def to_str_tokens( ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" else: raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}") - str_tokens = self.tokenizer.batch_decode( - tokens, clean_up_tokenization_spaces=False - ) + str_tokens = self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) return str_tokens def to_single_token(self, string): @@ -934,14 +883,10 @@ def to_single_str_token(self, int_token: int) -> str: def get_token_position( self, single_token: Union[str, int], - input: Union[ - str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]] - ], + input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]], mode="first", prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[ - Union[Literal["left", "right"], None] - ] = USE_DEFAULT_VALUE, + padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, ): """Get the position 
of a single_token in a string or sequence of tokens. @@ -972,9 +917,7 @@ def get_token_position( """ if isinstance(input, str): # If the input is a string, convert to tensor - tokens = self.to_tokens( - input, prepend_bos=prepend_bos, padding_side=padding_side - ) + tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) else: tokens = input @@ -991,9 +934,7 @@ def get_token_position( elif isinstance(single_token, torch.Tensor): single_token = single_token.item() - indices = torch.arange(len(tokens), device=tokens.device)[ - tokens == single_token - ] + indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token] assert len(indices) > 0, "The token does not occur in the prompt" if mode == "first": return indices[0].item() @@ -1085,12 +1026,8 @@ def move_model_modules_to_device(self): self.pos_embed.to(devices.get_device_for_block_index(0, self.cfg)) self.hook_pos_embed.to(devices.get_device_for_block_index(0, self.cfg)) if hasattr(self, "ln_final"): - self.ln_final.to( - devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg) - ) - self.unembed.to( - devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg) - ) + self.ln_final.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg)) + self.unembed.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg)) for i, block in enumerate(self.blocks): block.to(devices.get_device_for_block_index(i, self.cfg)) @@ -1265,9 +1202,7 @@ def from_pretrained( (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16) or dtype == torch.float16 ) and device in ["cpu", None]: - logging.warning( - "float16 models may not work on CPU. Consider using a GPU or bfloat16." - ) + logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.") # Get the model name used in HuggingFace, rather than the alias. 
official_model_name = loading.get_official_model_name(model_name) @@ -1463,13 +1398,9 @@ def _init_weights_kaiming(self, dist_type="uniform"): for name, param in self.named_parameters(): if "W_" in name: if dist_type == "uniform": - init_kaiming_uniform_( - param, gain=gain, nonlinearity="relu", mode="fan_in" - ) + init_kaiming_uniform_(param, gain=gain, nonlinearity="relu", mode="fan_in") elif dist_type == "normal": - init_kaiming_normal_( - param, gain=gain, nonlinearity="relu", mode="fan_in" - ) + init_kaiming_normal_(param, gain=gain, nonlinearity="relu", mode="fan_in") def _init_weights_muP(self, dist_type="uniform"): """ @@ -1617,14 +1548,10 @@ def fold_layer_norm( # the bias, we use the W_ matrix to map it to the hidden space of the layer, so we need # to sum along axis -2, which is the residual stream space axis. if fold_biases: - state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[ - f"blocks.{l}.attn.b_Q" - ] + ( + state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[f"blocks.{l}.attn.b_Q"] + ( state_dict[f"blocks.{l}.attn.W_Q"] * state_dict[f"blocks.{l}.ln1.b"][None, :, None] - ).sum( - -2 - ) + ).sum(-2) state_dict[f"blocks.{l}.attn.{gqa}b_K"] = state_dict[ f"blocks.{l}.attn.{gqa}b_K" ] + ( @@ -1644,8 +1571,7 @@ def fold_layer_norm( del state_dict[f"blocks.{l}.ln1.b"] state_dict[f"blocks.{l}.attn.W_Q"] = ( - state_dict[f"blocks.{l}.attn.W_Q"] - * state_dict[f"blocks.{l}.ln1.w"][None, :, None] + state_dict[f"blocks.{l}.attn.W_Q"] * state_dict[f"blocks.{l}.ln1.w"][None, :, None] ) state_dict[f"blocks.{l}.attn.{gqa}W_K"] = ( state_dict[f"blocks.{l}.attn.{gqa}W_K"] @@ -1682,19 +1608,14 @@ def fold_layer_norm( # Fold ln2 into MLP if not self.cfg.attn_only: if fold_biases: - state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[ - f"blocks.{l}.mlp.b_in" - ] + ( + state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[f"blocks.{l}.mlp.b_in"] + ( state_dict[f"blocks.{l}.mlp.W_in"] * state_dict[f"blocks.{l}.ln2.b"][:, None] - ).sum( - -2 - ) + ).sum(-2) del 
state_dict[f"blocks.{l}.ln2.b"] state_dict[f"blocks.{l}.mlp.W_in"] = ( - state_dict[f"blocks.{l}.mlp.W_in"] - * state_dict[f"blocks.{l}.ln2.w"][:, None] + state_dict[f"blocks.{l}.mlp.W_in"] * state_dict[f"blocks.{l}.ln2.w"][:, None] ) if self.cfg.gated_mlp: @@ -1751,9 +1672,7 @@ def fold_layer_norm( ).sum(dim=-2) del state_dict[f"ln_final.b"] - state_dict[f"unembed.W_U"] = ( - state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None] - ) + state_dict[f"unembed.W_U"] = state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None] del state_dict[f"ln_final.w"] if center_weights: @@ -1771,30 +1690,28 @@ def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]): W_out. This is done by subtracting the mean of the weights from the weights themselves. This is done in-place. See fold_layer_norm for more details. """ - state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict[ - "embed.W_E" - ].mean(-1, keepdim=True) + state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean( + -1, keepdim=True + ) if self.cfg.positional_embedding_type != "rotary": state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[ "pos_embed.W_pos" ].mean(-1, keepdim=True) for l in range(self.cfg.n_layers): - state_dict[f"blocks.{l}.attn.W_O"] = state_dict[ + state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[ f"blocks.{l}.attn.W_O" - ] - state_dict[f"blocks.{l}.attn.W_O"].mean( + ].mean( -1, keepdim=True ) # W_O is [head_index, d_model, d_head] state_dict[f"blocks.{l}.attn.b_O"] = ( - state_dict[f"blocks.{l}.attn.b_O"] - - state_dict[f"blocks.{l}.attn.b_O"].mean() + state_dict[f"blocks.{l}.attn.b_O"] - state_dict[f"blocks.{l}.attn.b_O"].mean() ) # b_O is [d_model] if not self.cfg.attn_only: state_dict[f"blocks.{l}.mlp.W_out"] = state_dict[ f"blocks.{l}.mlp.W_out" ] - state_dict[f"blocks.{l}.mlp.W_out"].mean(-1, keepdim=True) state_dict[f"blocks.{l}.mlp.b_out"] = ( - 
state_dict[f"blocks.{l}.mlp.b_out"] - - state_dict[f"blocks.{l}.mlp.b_out"].mean() + state_dict[f"blocks.{l}.mlp.b_out"] - state_dict[f"blocks.{l}.mlp.b_out"].mean() ) return state_dict @@ -1807,12 +1724,10 @@ def center_unembed(self, state_dict: Dict[str, torch.Tensor]): how components contribute to the logits, we'll be less misled by components that just add something to every logit. """ - state_dict["unembed.W_U"] = state_dict["unembed.W_U"] - state_dict[ - "unembed.W_U" - ].mean(-1, keepdim=True) - state_dict["unembed.b_U"] = ( - state_dict["unembed.b_U"] - state_dict["unembed.b_U"].mean() + state_dict["unembed.W_U"] = state_dict["unembed.W_U"] - state_dict["unembed.W_U"].mean( + -1, keepdim=True ) + state_dict["unembed.b_U"] = state_dict["unembed.b_U"] - state_dict["unembed.b_U"].mean() return state_dict def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]): @@ -2089,9 +2004,7 @@ def generate( assert ( self.tokenizer is not None ), "Must provide a tokenizer if passing a string to the model" - tokens = self.to_tokens( - input, prepend_bos=prepend_bos, padding_side=padding_side - ) + tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) else: tokens = input @@ -2117,8 +2030,7 @@ def generate( assert self.tokenizer is not None if stop_at_eos: tokenizer_has_eos_token = ( - self.tokenizer is not None - and self.tokenizer.eos_token_id is not None + self.tokenizer is not None and self.tokenizer.eos_token_id is not None ) if eos_token_id is None: assert ( @@ -2134,15 +2046,11 @@ def generate( # eos_token_id is a Sequence (e.g. list or tuple) stop_tokens = eos_token_id eos_token_for_padding = ( - self.tokenizer.eos_token_id - if tokenizer_has_eos_token - else eos_token_id[0] + self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0] ) # An array to track which sequences in the batch have finished. 
- finished_sequences = torch.zeros( - batch_size, dtype=torch.bool, device=self.cfg.device - ) + finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) # Currently nothing in HookedTransformer changes with eval, but this is here in case # that changes in the future. @@ -2369,9 +2277,7 @@ def accumulated_bias( if include_mlp_biases: accumulated_bias += self.blocks[i].mlp.b_out if mlp_input: - assert ( - layer < self.cfg.n_layers - ), "Cannot include attn_bias from beyond the final layer" + assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer" accumulated_bias += self.blocks[layer].attn.b_O return accumulated_bias @@ -2405,20 +2311,14 @@ def all_composition_scores( # layer than the left head. mask = ( torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None] - < torch.arange(self.cfg.n_layers, device=self.cfg.device)[ - None, None, :, None - ] + < torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None] ) scores = torch.where(mask, scores, torch.zeros_like(scores)) return scores def all_head_labels(self): """Returns a list of all head names in the model.""" - return [ - f"L{l}H{h}" - for l in range(self.cfg.n_layers) - for h in range(self.cfg.n_heads) - ] + return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] def load_sample_training_dataset(self, **kwargs): """Load Sample Training Dataset. diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py index 2ea815f0d..d41079bca 100644 --- a/transformer_lens/HookedTransformerConfig.py +++ b/transformer_lens/HookedTransformerConfig.py @@ -95,8 +95,8 @@ class HookedTransformerConfig: attn_only (bool): Whether to only use attention layers, no feedforward layers. Defaults to False seed (int, *optional*): The seed to use for the model. - Used to set sources of randomness (Python, PyTorch and - NumPy) and to initialize weights. Defaults to None. 
We recommend setting a seed, so your experiments are reproducible. + Used to set sources of randomness (Python, PyTorch and NumPy) and to initialize weights. + Defaults to None. We recommend setting a seed, so your experiments are reproducible. initializer_range (float): The standard deviation of the normal used to initialise the weights, initialized to 0.8 / sqrt(d_model). If weight_init_mode is 'xavier_uniform' or 'xavier_normal', this value is instead treated as the `gain` parameter for the weight @@ -217,11 +217,9 @@ def __post_init__(self): self.n_heads = self.d_model // self.d_head if not self.d_model % (self.d_head) == 0: - # logging.warning( - # f"d_model {self.d_model} is not divisible by d_head {self.d_head}. n_heads was inferred to be {self.n_heads}, rounding down the ratio." - # ) logging.warning( - "d_model %d is not divisible by d_head %d. n_heads was inferred to be %d, rounding down the ratio.", + "d_model %d is not divisible by d_head %d." + "n_heads was inferred to be %d, rounding down the ratio.", self.d_model, self.d_head, self.n_heads, @@ -230,19 +228,13 @@ def __post_init__(self): if self.seed is not None: self.set_seed_everywhere(self.seed) if self.use_local_attn: - assert ( - self.window_size is not None - ), "window_size must be specified for local attention" - assert ( - self.attn_types is not None - ), "attn_types must be specified for local attention" + assert self.window_size is not None, "window_size must be specified for local attention" + assert self.attn_types is not None, "attn_types must be specified for local attention" if not self.attn_only: if self.d_mlp is None: # For some reason everyone hard codes in this hyper-parameter! 
self.d_mlp: int = self.d_model * 4 - assert ( - self.act_fn is not None - ), "act_fn must be specified for non-attn-only models" + assert self.act_fn is not None, "act_fn must be specified for non-attn-only models" assert ( self.act_fn in SUPPORTED_ACTIVATIONS ), f"act_fn={self.act_fn} must be one of {SUPPORTED_ACTIVATIONS}" @@ -255,7 +247,8 @@ def __post_init__(self): if self.d_vocab_out == -1: # d_vocab_out defaults to d_vocab, unless there's an algorithmic task - # If d_vocab is not set, it'll be inferred from tokenizer_name or from a tokenizer explicitly passed to HookedTransformer initialisation. + # If d_vocab is not set, it'll be inferred from tokenizer_name or from a tokenizer + # explicitly passed to HookedTransformer initialisation. self.d_vocab_out = self.d_vocab if self.positional_embedding_type == "rotary" and self.rotary_dim is None: @@ -271,9 +264,7 @@ def __post_init__(self): ), "num_experts must be set if experts_per_token is set" # The number of parameters in attention layers (ignoring biases and layer norm). 4 because W_Q, W_K, W_V and W_O - self.n_params = self.n_layers * ( - (self.d_model * self.d_head * self.n_heads * 4) - ) + self.n_params = self.n_layers * ((self.d_model * self.d_head * self.n_heads * 4)) if not self.attn_only: assert self.d_mlp is not None # mypy # Number of parameters in MLP layers (ignoring biases and layer norm). 
2 because W_in and W_out @@ -281,9 +272,7 @@ def __post_init__(self): if self.num_experts: # If we are using MoE, we multiply by num_experts, and add the expert gate parameters (d_model * num_experts) - mlp_params_per_layer = ( - mlp_params_per_layer + self.d_model - ) * self.num_experts + mlp_params_per_layer = (mlp_params_per_layer + self.d_model) * self.num_experts self.n_params += self.n_layers * mlp_params_per_layer if self.device is None: diff --git a/transformer_lens/SVDInterpreter.py b/transformer_lens/SVDInterpreter.py index c34c34e6f..cf0354d61 100644 --- a/transformer_lens/SVDInterpreter.py +++ b/transformer_lens/SVDInterpreter.py @@ -96,13 +96,9 @@ def plot_matrix(matrix, tokens, k=10, filter="topk"): _, _, V = torch.linalg.svd(matrix) else: - raise ValueError( - f"Vector type must be in {VECTOR_TYPES}, instead got {vector_type}" - ) + raise ValueError(f"Vector type must be in {VECTOR_TYPES}, instead got {vector_type}") - return self._get_singular_vectors_from_matrix( - V, self.params[OUTPUT_EMBEDDING], num_vectors - ) + return self._get_singular_vectors_from_matrix(V, self.params[OUTPUT_EMBEDDING], num_vectors) def _get_singular_vectors_from_matrix( self, diff --git a/transformer_lens/components.py b/transformer_lens/components.py index dc72da6b1..a3cd7f5f4 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -83,9 +83,7 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): if isinstance(cfg, Dict): cfg = HookedTransformerConfig.from_dict(cfg) self.cfg = cfg - self.W_pos = nn.Parameter( - torch.empty(self.cfg.n_ctx, self.cfg.d_model, dtype=cfg.dtype) - ) + self.W_pos = nn.Parameter(torch.empty(self.cfg.n_ctx, self.cfg.d_model, dtype=cfg.dtype)) def forward( self, @@ -119,9 +117,7 @@ def forward( # Separated from the no padding case for computational efficiency # (this code is a bit slower than the code above) - offset_position_ids = get_offset_position_ids( - past_kv_pos_offset, attention_mask - ) + 
offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask) pos_embed = self.W_pos[offset_position_ids] # [batch, pos, d_model] # Set the position embeddings to 0 for pad tokens (this is an arbitrary choice) @@ -148,9 +144,7 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): if isinstance(cfg, Dict): cfg = HookedTransformerConfig.from_dict(cfg) self.cfg = cfg - self.W_token_type = nn.Parameter( - torch.empty(2, self.cfg.d_model, dtype=cfg.dtype) - ) + self.W_token_type = nn.Parameter(torch.empty(2, self.cfg.d_model, dtype=cfg.dtype)) def forward(self, token_type_ids: Int[torch.Tensor, "batch pos"]): return self.W_token_type[token_type_ids, :] @@ -181,9 +175,7 @@ def forward( token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, ): base_index_id = torch.arange(input_ids.shape[1], device=input_ids.device) - index_ids = einops.repeat( - base_index_id, "pos -> batch pos", batch=input_ids.shape[0] - ) + index_ids = einops.repeat(base_index_id, "pos -> batch pos", batch=input_ids.shape[0]) if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) @@ -193,9 +185,7 @@ def forward( self.token_type_embed(token_type_ids) ) - embeddings_out = ( - word_embeddings_out + position_embeddings_out + token_type_embeddings_out - ) + embeddings_out = word_embeddings_out + position_embeddings_out + token_type_embeddings_out layer_norm_out = self.ln(embeddings_out) return layer_norm_out @@ -273,9 +263,7 @@ def forward( class LayerNorm(nn.Module): - def __init__( - self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None - ): + def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None): """ LayerNorm with optional length parameter @@ -342,15 +330,11 @@ def forward( scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale( (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt() ) - return self.hook_normalized(x / scale).to( - self.cfg.dtype - ) # [batch, pos, length] + return 
self.hook_normalized(x / scale).to(self.cfg.dtype) # [batch, pos, length] class RMSNorm(nn.Module): - def __init__( - self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None - ): + def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None): """ RMSNorm - LayerNorm without the centering and bias (RMS = Root Mean Square) @@ -409,20 +393,14 @@ def __init__( cfg = HookedTransformerConfig.from_dict(cfg) self.cfg = cfg self.W_Q = nn.Parameter( - torch.empty( - self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype - ) + torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype) ) self.W_K = abstract_attribute() self.W_V = abstract_attribute() self.W_O = nn.Parameter( - torch.empty( - self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model, dtype=cfg.dtype - ) - ) - self.b_Q = nn.Parameter( - torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype) + torch.empty(self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model, dtype=cfg.dtype) ) + self.b_Q = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)) self.b_K = abstract_attribute() self.b_V = abstract_attribute() self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype)) @@ -437,9 +415,7 @@ def __init__( elif self.attn_type == "local": # For local, this is banded, query - window_size < key <= query assert isinstance(self.cfg.window_size, int) - self.register_buffer( - "mask", torch.triu(causal_mask, 1 - self.cfg.window_size) - ) + self.register_buffer("mask", torch.triu(causal_mask, 1 - self.cfg.window_size)) else: raise ValueError(f"Invalid attention type: {self.attn_type}") @@ -547,9 +523,7 @@ def forward( kv_cache_pos_offset = 0 if self.cfg.positional_embedding_type == "rotary": - q = self.hook_rot_q( - self.apply_rotary(q, kv_cache_pos_offset, attention_mask) - ) + q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask)) k = self.hook_rot_k( self.apply_rotary(k, 0, 
attention_mask) ) # keys are cached so no offset @@ -570,9 +544,7 @@ def forward( # only recompute when necessary to increase efficiency. if self.alibi is None or key_ctx > self.alibi.size(-1): - self.alibi = Attention.create_alibi_bias( - self.cfg.n_heads, key_ctx, self.cfg.device - ) + self.alibi = Attention.create_alibi_bias(self.cfg.n_heads, key_ctx, self.cfg.device) attn_scores += self.alibi[ :, :query_ctx, :key_ctx @@ -619,9 +591,7 @@ def forward( ) ) # [batch, pos, head_index, d_model] out = ( - einops.reduce( - result, "batch position index model->batch position model", "sum" - ) + einops.reduce(result, "batch position index model->batch position model", "sum") + self.b_O ) # [batch, pos, d_model] return out @@ -714,9 +684,7 @@ def calculate_z_scores( def apply_causal_mask( self, - attn_scores: Float[ - torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset" - ], + attn_scores: Float[torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"], past_kv_pos_offset: int = 0, attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, ): @@ -731,9 +699,7 @@ def apply_causal_mask( ), f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug." 
# Index back to front to ensure local attention works - final_mask = self.mask[ - None, None, -query_ctx_length:, -key_ctx_length: - ] # [1, 1, pos, pos] + final_mask = self.mask[None, None, -query_ctx_length:, -key_ctx_length:] # [1, 1, pos, pos] if attention_mask is not None: # Apply a causal mask to the attention scores considering the padding einsum_str = "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos" @@ -749,9 +715,7 @@ def calculate_sin_cos_rotary( n_ctx: int, base: int = 10000, dtype: torch.dtype = torch.float32, - ) -> Tuple[ - Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"] - ]: + ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]: """ Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details @@ -814,9 +778,7 @@ def apply_rotary( ] x_rotated = x_rot * rotary_cos + x_flip * rotary_sin else: - offset_position_ids = get_offset_position_ids( - past_kv_pos_offset, attention_mask - ) + offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask) offset_position_ids = offset_position_ids.to(self.rotary_cos.device) mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :] mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :] @@ -940,14 +902,12 @@ def create_alibi_bias( The ALiBi bias that should be added to the attention scores before the softmax. """ # Create the slope matrix - slope: Float[torch.Tensor, "query key"] = Attention.create_alibi_slope( - n_ctx, device - ) + slope: Float[torch.Tensor, "query key"] = Attention.create_alibi_slope(n_ctx, device) # Create the scalar multiplier for each head. 
- multipliers: Float[ - torch.Tensor, "head_idx" - ] = Attention.create_alibi_multipliers(n_heads, device) + multipliers: Float[torch.Tensor, "head_idx"] = Attention.create_alibi_multipliers( + n_heads, device + ) # The ALiBi bias is then m * slope_matrix alibi_bias = torch.einsum("ij,k->kij", slope, multipliers) @@ -977,21 +937,13 @@ def __init__( cfg = HookedTransformerConfig.from_dict(cfg) self.cfg = cfg self.W_K = nn.Parameter( - torch.empty( - self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype - ) + torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype) ) self.W_V = nn.Parameter( - torch.empty( - self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype - ) - ) - self.b_K = nn.Parameter( - torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype) - ) - self.b_V = nn.Parameter( - torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype) + torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype) ) + self.b_K = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)) + self.b_V = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)) class GroupedQueryAttention(AbstractAttention): @@ -1185,13 +1137,9 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): cfg = HookedTransformerConfig.from_dict(cfg) self.cfg = cfg assert self.cfg.d_mlp is not None # TODO: should this not be optional? 
- self.W_in = nn.Parameter( - torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype) - ) + self.W_in = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype)) self.b_in = nn.Parameter(torch.zeros(self.cfg.d_mlp, dtype=cfg.dtype)) - self.W_out = nn.Parameter( - torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=cfg.dtype) - ) + self.W_out = nn.Parameter(torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=cfg.dtype)) self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype)) self.hook_pre = HookPoint() # [batch, pos, d_mlp] @@ -1224,8 +1172,7 @@ def forward( ) -> Float[torch.Tensor, "batch pos d_model"]: # Technically, all these einsums could be done with a single matmul, but this is more readable. pre_act = self.hook_pre( - einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in) - + self.b_in + einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in) + self.b_in ) # [batch, pos, d_mlp] if self.cfg.act_fn is not None and not self.cfg.act_fn.endswith("_ln"): post_act = self.hook_post(self.act_fn(pre_act)) # [batch, pos, d_mlp] @@ -1264,16 +1211,10 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): cfg = HookedTransformerConfig.from_dict(cfg) self.cfg = cfg assert self.cfg.d_mlp is not None # keep mypy happy - self.W_in = nn.Parameter( - torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype) - ) - self.W_gate = nn.Parameter( - torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype) - ) + self.W_in = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype)) + self.W_gate = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype)) self.b_in = nn.Parameter(torch.zeros(self.cfg.d_mlp, dtype=cfg.dtype)) - self.W_out = nn.Parameter( - torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=cfg.dtype) - ) + self.W_out = nn.Parameter(torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=cfg.dtype)) self.b_out = 
nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype)) # hook on gate output but before act_fn @@ -1310,15 +1251,11 @@ def forward( ) -> Float[torch.Tensor, "batch pos d_model"]: # Technically, all these einsums could be done with a single matmul, but this is more readable. pre_act = self.hook_pre( - einsum( - "batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_gate - ) + einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_gate) ) # [batch, pos, d_mlp] if self.cfg.act_fn is not None and not self.cfg.act_fn.endswith("_ln"): pre_linear = self.hook_pre_linear( - einsum( - "batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in - ) + einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in) ) post_act = self.hook_post( (self.act_fn(pre_act) * pre_linear) + self.b_in @@ -1344,26 +1281,17 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): self.cfg = cfg # Ensure that num_experts and experts_per_token are specified and non-zero - assert ( - cfg.num_experts is not None - ), "num_experts must be specified for MoE layer" - assert ( - cfg.experts_per_token - ), "experts_per_token must be specified for MoE layer" + assert cfg.num_experts is not None, "num_experts must be specified for MoE layer" + assert cfg.experts_per_token, "experts_per_token must be specified for MoE layer" self.experts_per_token: int = cfg.experts_per_token assert ( cfg.experts_per_token <= cfg.num_experts ), "experts_per_token must be less than or equal to num_experts" self.experts = nn.ModuleList( - [ - GatedMLP(cfg) if cfg.gated_mlp else MLP(cfg) - for _ in range(cfg.num_experts) - ] - ) - self.W_gate = nn.Parameter( - torch.empty(cfg.d_model, cfg.num_experts, dtype=cfg.dtype) + [GatedMLP(cfg) if cfg.gated_mlp else MLP(cfg) for _ in range(cfg.num_experts)] ) + self.W_gate = nn.Parameter(torch.empty(cfg.d_model, cfg.num_experts, dtype=cfg.dtype)) # Hook on the weights of selected experts [batch pos experts_per_token] 
self.hook_expert_weights = HookPoint() @@ -1391,9 +1319,7 @@ def forward( # find the batch, pos, and expert indices which use this expert batch, pos, expert = torch.where(expert_indices == i) # accumulate the weighted outputs from the expert - results[batch] += weights[batch, pos, expert, None, None] * expert_mlp( - x[batch] - ) + results[batch] += weights[batch, pos, expert, None, None] * expert_mlp(x[batch]) return results @@ -1431,13 +1357,9 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index): if not self.cfg.attn_only: self.ln2 = nn.Identity() else: - logging.warning( - f"Invalid normalization_type passed in {self.cfg.normalization_type}" - ) + logging.warning(f"Invalid normalization_type passed in {self.cfg.normalization_type}") - attention = ( - Attention if self.cfg.n_key_value_heads is None else GroupedQueryAttention - ) + attention = Attention if self.cfg.n_key_value_heads is None else GroupedQueryAttention if not self.cfg.use_local_attn: self.attn = attention(cfg, "global", block_index) else: @@ -1469,9 +1391,7 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index): def forward( self, resid_pre: Float[torch.Tensor, "batch pos d_model"], - shortformer_pos_embed: Optional[ - Float[torch.Tensor, "batch pos d_model"] - ] = None, + shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, ) -> Float[torch.Tensor, "batch pos d_model"]: @@ -1538,39 +1458,25 @@ def add_head_dimension( ) ) # [batch, pos, d_model] if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp: - resid_mid = self.hook_resid_mid( - resid_pre + attn_out - ) # [batch, pos, d_model] + resid_mid = self.hook_resid_mid(resid_pre + attn_out) # [batch, pos, d_model] mlp_in = ( - resid_mid - if not self.cfg.use_hook_mlp_in - else self.hook_mlp_in(resid_mid.clone()) + resid_mid 
if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone()) ) normalized_resid_mid = self.ln2(mlp_in) - mlp_out = self.hook_mlp_out( - self.mlp(normalized_resid_mid) - ) # [batch, pos, d_model] - resid_post = self.hook_resid_post( - resid_mid + mlp_out - ) # [batch, pos, d_model] + mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid)) # [batch, pos, d_model] + resid_post = self.hook_resid_post(resid_mid + mlp_out) # [batch, pos, d_model] elif self.cfg.parallel_attn_mlp: # Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used. # In GPT-J, LN1 and LN2 are tied, in GPT-NeoX they aren't. normalized_resid_pre_2 = self.ln2( - resid_pre - if not self.cfg.use_hook_mlp_in - else self.hook_mlp_in(resid_pre.clone()) + resid_pre if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_pre.clone()) ) - mlp_out = self.hook_mlp_out( - self.mlp(normalized_resid_pre_2) - ) # [batch, pos, d_model] + mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_pre_2)) # [batch, pos, d_model] resid_post = self.hook_resid_post( resid_pre + attn_out + mlp_out ) # [batch, pos, d_model] else: - resid_post = self.hook_resid_post( - resid_pre + attn_out - ) # [batch, pos, d_model] + resid_post = self.hook_resid_post(resid_pre + attn_out) # [batch, pos, d_model] return resid_post @@ -1634,11 +1540,7 @@ def add_head_dimension(tensor): ) resid_mid = self.hook_resid_mid(resid_pre + attn_out) - mlp_in = ( - resid_mid - if not self.cfg.use_hook_mlp_in - else self.hook_mlp_in(resid_mid.clone()) - ) + mlp_in = resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone()) normalized_resid_mid = self.ln1(mlp_in) mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid)) resid_post = self.hook_resid_post(normalized_resid_mid + mlp_out) diff --git a/transformer_lens/evals.py b/transformer_lens/evals.py index 710560491..b77c727c5 100644 --- a/transformer_lens/evals.py +++ b/transformer_lens/evals.py @@ -40,9 +40,7 
@@ def make_wiki_data_loader(tokenizer, batch_size=8): wiki_data = load_dataset("wikitext", "wikitext-2-v1", split="train") print(len(wiki_data)) dataset = utils.tokenize_and_concatenate(wiki_data, tokenizer) - data_loader = DataLoader( - dataset, batch_size=batch_size, shuffle=True, drop_last=True - ) + data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) return data_loader @@ -55,9 +53,7 @@ def make_owt_data_loader(tokenizer, batch_size=8): owt_data = load_dataset("stas/openwebtext-10k", split="train") print(len(owt_data)) dataset = utils.tokenize_and_concatenate(owt_data, tokenizer) - data_loader = DataLoader( - dataset, batch_size=batch_size, shuffle=True, drop_last=True - ) + data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) return data_loader @@ -71,9 +67,7 @@ def make_pile_data_loader(tokenizer, batch_size=8): pile_data = load_dataset("NeelNanda/pile-10k", split="train") print(len(pile_data)) dataset = utils.tokenize_and_concatenate(pile_data, tokenizer) - data_loader = DataLoader( - dataset, batch_size=batch_size, shuffle=True, drop_last=True - ) + data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) return data_loader @@ -86,12 +80,8 @@ def make_code_data_loader(tokenizer, batch_size=8): """ code_data = load_dataset("codeparrot/codeparrot-valid-v2-near-dedup", split="train") print(len(code_data)) - dataset = utils.tokenize_and_concatenate( - code_data, tokenizer, column_name="content" - ) - data_loader = DataLoader( - dataset, batch_size=batch_size, shuffle=True, drop_last=True - ) + dataset = utils.tokenize_and_concatenate(code_data, tokenizer, column_name="content") + data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) return data_loader @@ -146,9 +136,7 @@ def induction_loss( repeated_tokens[:, 0] = tokenizer.bos_token_id # Run the model, and extract the per token correct log prob logits = model(repeated_tokens, 
return_type="logits") - correct_log_probs = utils.lm_cross_entropy_loss( - logits, repeated_tokens, per_token=True - ) + correct_log_probs = utils.lm_cross_entropy_loss(logits, repeated_tokens, per_token=True) # Take the loss over the second half of the sequence return correct_log_probs[:, subseq_len + 1 :].mean() @@ -212,9 +200,7 @@ def __init__( self.tokenizer = tokenizer self.prepend_bos = prepend_bos - self.templates = ( - templates if templates is not None else self.get_default_templates() - ) + self.templates = templates if templates is not None else self.get_default_templates() self.names = names if names is not None else self.get_default_names() self.nouns = nouns if nouns is not None else self.get_default_nouns() @@ -256,9 +242,7 @@ def get_sample(self, symmetric=False) -> List[Dict[str, str]]: if symmetric: sample_2 = template.replace("[A]", names[1]) sample_2 = sample_2.replace("[B]", names[0]) - samples.append( - {"text": sample_2, "IO": " " + names[1], "S": " " + names[0]} - ) + samples.append({"text": sample_2, "IO": " " + names[1], "S": " " + names[0]}) return samples @@ -282,9 +266,7 @@ def get_default_nouns(): @torch.inference_mode() -def ioi_eval( - model, dataset=None, batch_size=8, num_samples=1000, tokenizer=None, symmetric=False -): +def ioi_eval(model, dataset=None, batch_size=8, num_samples=1000, tokenizer=None, symmetric=False): """Evaluate the Model on the Indirect Object Identification Task. 
Args: @@ -314,9 +296,7 @@ def collate(samples): "prompt_length": [p.shape[0] for p in prompts], } - data_loader = DataLoader( - dataset, batch_size=batch_size, shuffle=True, collate_fn=collate - ) + data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate) total_correct = 0 total_logit_diff = 0 diff --git a/transformer_lens/head_detector.py b/transformer_lens/head_detector.py index 13964c358..fbb50fae8 100644 --- a/transformer_lens/head_detector.py +++ b/transformer_lens/head_detector.py @@ -26,9 +26,7 @@ f"detection_pattern must be a Tensor or one of head names: {HEAD_NAMES}; got %s" ) -SEQ_LEN_ERR = ( - "The sequence must be non-empty and must fit within the model's context window." -) +SEQ_LEN_ERR = "The sequence must be non-empty and must fit within the model's context window." DET_PAT_NOT_SQUARE_ERR = "The detection pattern must be a lower triangular matrix of shape (sequence_length, sequence_length); sequence_length=%d; got detection patern of shape %s" @@ -113,9 +111,7 @@ def detect_head( # Validate detection pattern if it's a string if isinstance(detection_pattern, str): - assert detection_pattern in HEAD_NAMES, ( - INVALID_HEAD_NAME_ERR % detection_pattern - ) + assert detection_pattern in HEAD_NAMES, INVALID_HEAD_NAME_ERR % detection_pattern if isinstance(seq, list): batch_scores = [detect_head(model, seq, detection_pattern) for seq in seq] return torch.stack(batch_scores).mean(0) @@ -125,9 +121,7 @@ def detect_head( ).to(cfg.device) # if we're using "mul", detection_pattern should consist of zeros and ones - if error_measure == "mul" and not set(detection_pattern.unique().tolist()).issubset( - {0, 1} - ): + if error_measure == "mul" and not set(detection_pattern.unique().tolist()).issubset({0, 1}): logging.warning( "Using detection pattern with values other than 0 or 1 with error_measure 'mul'" ) @@ -142,9 +136,7 @@ def detect_head( _, cache = model.run_with_cache(tokens, remove_batch_dim=True) if heads is None: - 
layer2heads = { - layer_i: list(range(cfg.n_heads)) for layer_i in range(cfg.n_layers) - } + layer2heads = {layer_i: list(range(cfg.n_heads)) for layer_i in range(cfg.n_layers)} elif isinstance(heads, list): layer2heads = defaultdict(list) for layer, head in heads: @@ -200,9 +192,7 @@ def get_duplicate_token_head_detection_pattern( # If token_pattern[i][j] matches its transpose, then token j and token i are duplicates. eq_mask = np.equal(token_pattern, token_pattern.T).astype(int) - np.fill_diagonal( - eq_mask, 0 - ) # Current token is always a duplicate of itself. Ignore that. + np.fill_diagonal(eq_mask, 0) # Current token is always a duplicate of itself. Ignore that. detection_pattern = eq_mask.astype(int) return torch.tril(torch.as_tensor(detection_pattern).float()) diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py index 831ee11b1..fb01e2a50 100644 --- a/transformer_lens/hook_points.py +++ b/transformer_lens/hook_points.py @@ -54,9 +54,7 @@ def __init__(self): def add_perma_hook(self, hook, dir="fwd") -> None: self.add_hook(hook, dir=dir, is_permanent=True) - def add_hook( - self, hook, dir="fwd", is_permanent=False, level=None, prepend=False - ) -> None: + def add_hook(self, hook, dir="fwd", is_permanent=False, level=None, prepend=False) -> None: """ Hook format is fn(activation, hook_name) Change it into PyTorch hook format (this includes input and output, @@ -111,9 +109,7 @@ def _remove_hooks(handles: List[LensHandle]) -> List[LensHandle]: for handle in handles: if including_permanent: handle.hook.remove() - elif (not handle.is_permanent) and ( - level is None or handle.context_level == level - ): + elif (not handle.is_permanent) and (level is None or handle.context_level == level): handle.hook.remove() else: output_handles.append(handle) @@ -190,13 +186,9 @@ def setup(self): def hook_points(self): return self.hook_dict.values() - def remove_all_hook_fns( - self, direction="both", including_permanent=False, level=None - ): + def 
remove_all_hook_fns(self, direction="both", including_permanent=False, level=None): for hp in self.hook_points(): - hp.remove_hooks( - direction, including_permanent=including_permanent, level=level - ) + hp.remove_hooks(direction, including_permanent=including_permanent, level=level) def clear_contexts(self): for hp in self.hook_points(): @@ -233,9 +225,7 @@ def check_and_add_hook( is_permanent=is_permanent, prepend=prepend, ) - hook_point.add_hook( - hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend - ) + hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend) def check_hooks_to_add( self, @@ -309,9 +299,7 @@ def hooks( for name, hook in fwd_hooks: if isinstance(name, str): - self.mod_dict[name].add_hook( - hook, dir="fwd", level=self.context_level - ) + self.mod_dict[name].add_hook(hook, dir="fwd", level=self.context_level) else: # Otherwise, name is a Boolean function on names for hook_name, hp in self.hook_dict.items(): @@ -319,9 +307,7 @@ def hooks( hp.add_hook(hook, dir="fwd", level=self.context_level) for name, hook in bwd_hooks: if isinstance(name, str): - self.mod_dict[name].add_hook( - hook, dir="bwd", level=self.context_level - ) + self.mod_dict[name].add_hook(hook, dir="bwd", level=self.context_level) else: # Otherwise, name is a Boolean function on names for hook_name, hp in self.hook_dict: # type: ignore @@ -370,9 +356,7 @@ def run_with_hooks( "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur." 
) - with self.hooks( - fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts - ) as hooked_model: + with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model: return hooked_model.forward(*model_args, **model_kwargs) def add_caching_hooks( diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index e417dd79c..d51295780 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -604,8 +604,7 @@ # Sets a default model alias, by convention the first one in the model alias table, else the official name if it has no aliases DEFAULT_MODEL_ALIASES = [ - MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name - for name in OFFICIAL_MODEL_NAMES + MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES ] NEED_REMOTE_CODE_MODELS = ( @@ -682,9 +681,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "final_rms": True, "gated_mlp": True, } - elif official_model_name.startswith( - "CodeLlama-7b" - ): # same architecture CodeLlama and Llama-2 + elif official_model_name.startswith("CodeLlama-7b"): # same architecture CodeLlama and Llama-2 cfg_dict = { "d_model": 4096, "d_head": 4096 // 32, @@ -1115,9 +1112,7 @@ def convert_neel_model_config(official_model_name: str, **kwargs): AutoConfig is not supported, because these models are in the HookedTransformer format, so we directly download and load the json. 
""" official_model_name = get_official_model_name(official_model_name) - cfg_json: dict = utils.download_file_from_hf( - official_model_name, "config.json", **kwargs - ) + cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs) cfg_arch = cfg_json.get( "architecture", "neel" if "_old" not in official_model_name else "neel-solu-old" ) @@ -1371,9 +1366,7 @@ def get_pretrained_state_dict( )[0] else: file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0] - state_dict = utils.download_file_from_hf( - official_model_name, file_name, **kwargs - ) + state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs) # Convert to dtype state_dict = {k: v.to(dtype) for k, v in state_dict.items()} @@ -1400,14 +1393,10 @@ def get_pretrained_state_dict( **kwargs, ) else: - raise ValueError( - f"Checkpoints for model {official_model_name} are not supported" - ) + raise ValueError(f"Checkpoints for model {official_model_name} are not supported") elif hf_model is None: if official_model_name in NON_HF_HOSTED_MODEL_NAMES: - raise NotImplementedError( - "Model not hosted on HuggingFace, must pass in hf_model" - ) + raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model") elif "bert" in official_model_name: hf_model = BertForPreTraining.from_pretrained( official_model_name, torch_dtype=dtype, **kwargs @@ -1567,22 +1556,14 @@ def convert_neo_weights(neo, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.attn.W_K"] = W_K state_dict[f"blocks.{l}.attn.W_V"] = W_V - state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( - cfg.n_heads, cfg.d_head, dtype=cfg.dtype - ) - state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros( - cfg.n_heads, cfg.d_head, dtype=cfg.dtype - ) - state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros( - cfg.n_heads, cfg.d_head, dtype=cfg.dtype - ) + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn.b_K"] = 
torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) W_O = neo.transformer.h[l].attn.attention.out_proj.weight W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) state_dict[f"blocks.{l}.attn.W_O"] = W_O - state_dict[f"blocks.{l}.attn.b_O"] = neo.transformer.h[ - l - ].attn.attention.out_proj.bias + state_dict[f"blocks.{l}.attn.b_O"] = neo.transformer.h[l].attn.attention.out_proj.bias state_dict[f"blocks.{l}.ln2.w"] = neo.transformer.h[l].ln_2.weight state_dict[f"blocks.{l}.ln2.b"] = neo.transformer.h[l].ln_2.bias @@ -1619,15 +1600,9 @@ def convert_gptj_weights(gptj, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.attn.W_K"] = W_K state_dict[f"blocks.{l}.attn.W_V"] = W_V - state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( - cfg.n_heads, cfg.d_head, dtype=cfg.dtype - ) - state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros( - cfg.n_heads, cfg.d_head, dtype=cfg.dtype - ) - state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros( - cfg.n_heads, cfg.d_head, dtype=cfg.dtype - ) + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) W_O = gptj.transformer.h[l].attn.out_proj.weight W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) @@ -1689,30 +1664,16 @@ def convert_neox_weights(neox, cfg: HookedTransformerConfig): W_O = neox.gpt_neox.layers[l].attention.dense.weight W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) state_dict[f"blocks.{l}.attn.W_O"] = W_O - state_dict[f"blocks.{l}.attn.b_O"] = neox.gpt_neox.layers[ - l - ].attention.dense.bias + state_dict[f"blocks.{l}.attn.b_O"] = neox.gpt_neox.layers[l].attention.dense.bias - state_dict[f"blocks.{l}.ln2.w"] = neox.gpt_neox.layers[ - l - ].post_attention_layernorm.weight - 
state_dict[f"blocks.{l}.ln2.b"] = neox.gpt_neox.layers[ - l - ].post_attention_layernorm.bias + state_dict[f"blocks.{l}.ln2.w"] = neox.gpt_neox.layers[l].post_attention_layernorm.weight + state_dict[f"blocks.{l}.ln2.b"] = neox.gpt_neox.layers[l].post_attention_layernorm.bias - state_dict[f"blocks.{l}.mlp.W_in"] = neox.gpt_neox.layers[ - l - ].mlp.dense_h_to_4h.weight.T - state_dict[f"blocks.{l}.mlp.b_in"] = neox.gpt_neox.layers[ - l - ].mlp.dense_h_to_4h.bias + state_dict[f"blocks.{l}.mlp.W_in"] = neox.gpt_neox.layers[l].mlp.dense_h_to_4h.weight.T + state_dict[f"blocks.{l}.mlp.b_in"] = neox.gpt_neox.layers[l].mlp.dense_h_to_4h.bias - state_dict[f"blocks.{l}.mlp.W_out"] = neox.gpt_neox.layers[ - l - ].mlp.dense_4h_to_h.weight.T - state_dict[f"blocks.{l}.mlp.b_out"] = neox.gpt_neox.layers[ - l - ].mlp.dense_4h_to_h.bias + state_dict[f"blocks.{l}.mlp.W_out"] = neox.gpt_neox.layers[l].mlp.dense_4h_to_h.weight.T + state_dict[f"blocks.{l}.mlp.b_out"] = neox.gpt_neox.layers[l].mlp.dense_4h_to_h.bias state_dict["ln_final.w"] = neox.gpt_neox.final_layer_norm.weight state_dict["ln_final.b"] = neox.gpt_neox.final_layer_norm.bias @@ -1775,21 +1736,15 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): cfg.d_model, dtype=cfg.dtype, device=cfg.device ) - state_dict[f"blocks.{l}.ln2.w"] = llama.model.layers[ - l - ].post_attention_layernorm.weight + state_dict[f"blocks.{l}.ln2.w"] = llama.model.layers[l].post_attention_layernorm.weight state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight.T - state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[ - l - ].mlp.gate_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[l].mlp.gate_proj.weight.T state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros( cfg.d_mlp, dtype=cfg.dtype, device=cfg.device ) - state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[ - l - ].mlp.down_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[l].mlp.down_proj.weight.T 
state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros( cfg.d_model, dtype=cfg.dtype, device=cfg.device ) @@ -1797,9 +1752,7 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): state_dict["ln_final.w"] = llama.model.norm.weight state_dict["unembed.W_U"] = llama.lm_head.weight.T - state_dict["unembed.b_U"] = torch.zeros( - cfg.d_vocab, dtype=cfg.dtype, device=cfg.device - ) + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device) return state_dict @@ -1814,9 +1767,7 @@ def convert_qwen_weights(qwen, cfg: HookedTransformerConfig): for l in range(cfg.n_layers): state_dict[f"blocks.{l}.ln1.w"] = model.h[l].ln_1.weight - W_Q, W_K, W_V = model.h[l].attn.c_attn.weight.split( - split_size=cfg.d_model, dim=0 - ) + W_Q, W_K, W_V = model.h[l].attn.c_attn.weight.split(split_size=cfg.d_model, dim=0) W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads) W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads) @@ -1922,19 +1873,13 @@ def convert_qwen2_weights(qwen, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) - state_dict[f"blocks.{l}.ln2.w"] = qwen.model.layers[ - l - ].post_attention_layernorm.weight + state_dict[f"blocks.{l}.ln2.w"] = qwen.model.layers[l].post_attention_layernorm.weight state_dict[f"blocks.{l}.mlp.W_in"] = qwen.model.layers[l].mlp.up_proj.weight.T - state_dict[f"blocks.{l}.mlp.W_gate"] = qwen.model.layers[ - l - ].mlp.gate_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_gate"] = qwen.model.layers[l].mlp.gate_proj.weight.T state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) - state_dict[f"blocks.{l}.mlp.W_out"] = qwen.model.layers[ - l - ].mlp.down_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_out"] = qwen.model.layers[l].mlp.down_proj.weight.T state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) state_dict["ln_final.w"] = 
qwen.model.norm.weight @@ -1967,9 +1912,7 @@ def convert_mistral_weights(mistral, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.attn._W_K"] = W_K state_dict[f"blocks.{l}.attn._W_V"] = W_V - state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( - cfg.n_heads, cfg.d_head, dtype=cfg.dtype - ) + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype ) @@ -1983,21 +1926,13 @@ def convert_mistral_weights(mistral, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) - state_dict[f"blocks.{l}.ln2.w"] = mistral.model.layers[ - l - ].post_attention_layernorm.weight + state_dict[f"blocks.{l}.ln2.w"] = mistral.model.layers[l].post_attention_layernorm.weight - state_dict[f"blocks.{l}.mlp.W_in"] = mistral.model.layers[ - l - ].mlp.up_proj.weight.T - state_dict[f"blocks.{l}.mlp.W_gate"] = mistral.model.layers[ - l - ].mlp.gate_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_in"] = mistral.model.layers[l].mlp.up_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_gate"] = mistral.model.layers[l].mlp.gate_proj.weight.T state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) - state_dict[f"blocks.{l}.mlp.W_out"] = mistral.model.layers[ - l - ].mlp.down_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_out"] = mistral.model.layers[l].mlp.down_proj.weight.T state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) state_dict["ln_final.w"] = mistral.model.norm.weight @@ -2033,9 +1968,7 @@ def convert_mixtral_weights(mixtral, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.attn._W_K"] = W_K state_dict[f"blocks.{l}.attn._W_V"] = W_V - state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( - cfg.n_heads, cfg.d_head, dtype=cfg.dtype - ) + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 
state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype ) @@ -2049,9 +1982,7 @@ def convert_mixtral_weights(mixtral, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) - state_dict[f"blocks.{l}.ln2.w"] = mixtral.model.layers[ - l - ].post_attention_layernorm.weight + state_dict[f"blocks.{l}.ln2.w"] = mixtral.model.layers[l].post_attention_layernorm.weight state_dict[f"blocks.{l}.mlp.W_gate"] = mixtral.model.layers[ l @@ -2069,9 +2000,7 @@ def convert_mixtral_weights(mixtral, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate"] = ( mixtral.model.layers[l].block_sparse_moe.experts[e].w1.weight.T ) - state_dict[f"blocks.{l}.mlp.experts.{e}.b_in"] = torch.zeros( - cfg.d_mlp, dtype=cfg.dtype - ) + state_dict[f"blocks.{l}.mlp.experts.{e}.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) state_dict[f"blocks.{l}.mlp.experts.{e}.W_out"] = ( mixtral.model.layers[l].block_sparse_moe.experts[e].w2.weight.T ) @@ -2094,12 +2023,8 @@ def convert_opt_weights(opt, cfg: HookedTransformerConfig): state_dict["pos_embed.W_pos"] = opt.model.decoder.embed_positions.weight[2:, :] for l in range(cfg.n_layers): - state_dict[f"blocks.{l}.ln1.w"] = opt.model.decoder.layers[ - l - ].self_attn_layer_norm.weight - state_dict[f"blocks.{l}.ln1.b"] = opt.model.decoder.layers[ - l - ].self_attn_layer_norm.bias + state_dict[f"blocks.{l}.ln1.w"] = opt.model.decoder.layers[l].self_attn_layer_norm.weight + state_dict[f"blocks.{l}.ln1.b"] = opt.model.decoder.layers[l].self_attn_layer_norm.bias W_Q = opt.model.decoder.layers[l].self_attn.q_proj.weight W_K = opt.model.decoder.layers[l].self_attn.k_proj.weight @@ -2154,16 +2079,10 @@ def convert_opt_weights(opt, cfg: HookedTransformerConfig): index=cfg.n_heads, ) state_dict[f"blocks.{l}.attn.W_O"] = W_O - state_dict[f"blocks.{l}.attn.b_O"] = opt.model.decoder.layers[ - l - ].self_attn.out_proj.bias + 
state_dict[f"blocks.{l}.attn.b_O"] = opt.model.decoder.layers[l].self_attn.out_proj.bias - state_dict[f"blocks.{l}.ln2.w"] = opt.model.decoder.layers[ - l - ].final_layer_norm.weight - state_dict[f"blocks.{l}.ln2.b"] = opt.model.decoder.layers[ - l - ].final_layer_norm.bias + state_dict[f"blocks.{l}.ln2.w"] = opt.model.decoder.layers[l].final_layer_norm.weight + state_dict[f"blocks.{l}.ln2.b"] = opt.model.decoder.layers[l].final_layer_norm.bias state_dict[f"blocks.{l}.mlp.W_in"] = opt.model.decoder.layers[l].fc1.weight.T state_dict[f"blocks.{l}.mlp.W_out"] = opt.model.decoder.layers[l].fc2.weight.T @@ -2253,9 +2172,7 @@ def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig): W_O = old_state_dict[f"blocks.{l}.attn.proj.weight"] W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) state_dict[f"blocks.{l}.attn.W_O"] = W_O - state_dict[f"blocks.{l}.attn.b_O"] = old_state_dict[ - f"blocks.{l}.attn.proj.bias" - ] + state_dict[f"blocks.{l}.attn.b_O"] = old_state_dict[f"blocks.{l}.attn.proj.bias"] state_dict[f"blocks.{l}.ln2.w"] = old_state_dict[f"blocks.{l}.ln2.weight"] state_dict[f"blocks.{l}.ln2.b"] = old_state_dict[f"blocks.{l}.ln2.bias"] @@ -2294,9 +2211,7 @@ def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig): new_state_dict["embed.W_E"] = old_state_dict["transformer.wte.weight"] new_state_dict["ln_final.w"] = old_state_dict["transformer.ln_f.weight"] - new_state_dict["ln_final.b"] = torch.zeros_like( - old_state_dict["transformer.ln_f.weight"] - ) + new_state_dict["ln_final.b"] = torch.zeros_like(old_state_dict["transformer.ln_f.weight"]) new_state_dict["unembed.W_U"] = old_state_dict["lm_head.weight"].T bias = False @@ -2307,16 +2222,12 @@ def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig): for layer in range(cfg.n_layers): layer_key = f"transformer.h.{layer}" - new_state_dict[f"blocks.{layer}.ln1.w"] = old_state_dict[ - f"{layer_key}.ln_1.weight" - ] + new_state_dict[f"blocks.{layer}.ln1.w"] = 
old_state_dict[f"{layer_key}.ln_1.weight"] # A bias of zeros is required for folding layer norm new_state_dict[f"blocks.{layer}.ln1.b"] = torch.zeros_like( old_state_dict[f"{layer_key}.ln_1.weight"] ) - new_state_dict[f"blocks.{layer}.ln2.w"] = old_state_dict[ - f"{layer_key}.ln_2.weight" - ] + new_state_dict[f"blocks.{layer}.ln2.w"] = old_state_dict[f"{layer_key}.ln_2.weight"] new_state_dict[f"blocks.{layer}.ln2.b"] = torch.zeros_like( old_state_dict[f"{layer_key}.ln_2.weight"] ) @@ -2342,12 +2253,8 @@ def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig): ].T if bias: - new_state_dict[f"blocks.{layer}.ln1.b"] = old_state_dict[ - f"{layer_key}.ln_1.bias" - ] - new_state_dict[f"blocks.{layer}.ln2.b"] = old_state_dict[ - f"{layer_key}.ln_2.bias" - ] + new_state_dict[f"blocks.{layer}.ln1.b"] = old_state_dict[f"{layer_key}.ln_1.bias"] + new_state_dict[f"blocks.{layer}.ln2.b"] = old_state_dict[f"{layer_key}.ln_2.bias"] new_state_dict[f"blocks.{layer}.mlp.b_in"] = old_state_dict[ f"{layer_key}.mlp.c_fc.bias" ] @@ -2465,32 +2372,20 @@ def convert_bloom_weights(bloom, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[:, 2, :] W_O = bloom.transformer.h[l].self_attention.dense.weight.T # [1024, 1024] - W_O = einops.rearrange( - W_O, "(n h) m->n h m", n=cfg.n_heads - ) # [n_heads, d_head, d_model] + W_O = einops.rearrange(W_O, "(n h) m->n h m", n=cfg.n_heads) # [n_heads, d_head, d_model] state_dict[f"blocks.{l}.attn.W_O"] = W_O - state_dict[f"blocks.{l}.attn.b_O"] = bloom.transformer.h[ - l - ].self_attention.dense.bias + state_dict[f"blocks.{l}.attn.b_O"] = bloom.transformer.h[l].self_attention.dense.bias - state_dict[f"blocks.{l}.ln2.w"] = bloom.transformer.h[ - l - ].post_attention_layernorm.weight - state_dict[f"blocks.{l}.ln2.b"] = bloom.transformer.h[ - l - ].post_attention_layernorm.bias + state_dict[f"blocks.{l}.ln2.w"] = bloom.transformer.h[l].post_attention_layernorm.weight + state_dict[f"blocks.{l}.ln2.b"] = 
bloom.transformer.h[l].post_attention_layernorm.bias W_in = bloom.transformer.h[l].mlp.dense_h_to_4h.weight.T state_dict[f"blocks.{l}.mlp.W_in"] = W_in - state_dict[f"blocks.{l}.mlp.b_in"] = bloom.transformer.h[ - l - ].mlp.dense_h_to_4h.bias + state_dict[f"blocks.{l}.mlp.b_in"] = bloom.transformer.h[l].mlp.dense_h_to_4h.bias W_out = bloom.transformer.h[l].mlp.dense_4h_to_h.weight.T state_dict[f"blocks.{l}.mlp.W_out"] = W_out - state_dict[f"blocks.{l}.mlp.b_out"] = bloom.transformer.h[ - l - ].mlp.dense_4h_to_h.bias + state_dict[f"blocks.{l}.mlp.b_out"] = bloom.transformer.h[l].mlp.dense_4h_to_h.bias state_dict["unembed.W_U"] = bloom.lm_head.weight.T state_dict["ln_final.w"] = bloom.transformer.ln_f.weight @@ -2585,15 +2480,9 @@ def convert_phi_weights(phi, cfg: HookedTransformerConfig): b_Q = phi.model.layers[l].self_attn.q_proj.bias b_K = phi.model.layers[l].self_attn.k_proj.bias b_V = phi.model.layers[l].self_attn.v_proj.bias - b_Q = einops.rearrange( - b_Q, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads - ) - b_K = einops.rearrange( - b_K, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads - ) - b_V = einops.rearrange( - b_V, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads - ) + b_Q = einops.rearrange(b_Q, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads) + b_K = einops.rearrange(b_K, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads) + b_V = einops.rearrange(b_V, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads) state_dict[f"blocks.{l}.attn.b_Q"] = b_Q state_dict[f"blocks.{l}.attn.b_K"] = b_K state_dict[f"blocks.{l}.attn.b_V"] = b_V @@ -2652,9 +2541,7 @@ def convert_gemma_weights(gemma, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.attn._W_K"] = W_K state_dict[f"blocks.{l}.attn._W_V"] = W_V - state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( - cfg.n_heads, cfg.d_head, dtype=cfg.dtype - ) + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) state_dict[f"blocks.{l}.attn._b_K"] = 
torch.zeros( cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype ) @@ -2676,14 +2563,10 @@ def convert_gemma_weights(gemma, cfg: HookedTransformerConfig): ) state_dict[f"blocks.{l}.mlp.W_in"] = gemma.model.layers[l].mlp.up_proj.weight.T - state_dict[f"blocks.{l}.mlp.W_gate"] = gemma.model.layers[ - l - ].mlp.gate_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_gate"] = gemma.model.layers[l].mlp.gate_proj.weight.T state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) - state_dict[f"blocks.{l}.mlp.W_out"] = gemma.model.layers[ - l - ].mlp.down_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_out"] = gemma.model.layers[l].mlp.down_proj.weight.T state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) # GemmaRMSNorm adds 1 to weights before multiplying by input @@ -2716,9 +2599,7 @@ def get_basic_config(model_name: str, **kwargs) -> Config: return Config( **{ k: v - for k, v in get_pretrained_model_config(model_name, **kwargs) - .to_dict() - .items() + for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items() if k in [ "d_model", diff --git a/transformer_lens/past_key_value_caching.py b/transformer_lens/past_key_value_caching.py index 0aca31278..2f904b927 100644 --- a/transformer_lens/past_key_value_caching.py +++ b/transformer_lens/past_key_value_caching.py @@ -27,9 +27,7 @@ def init_cache_entry( device: Union[torch.device, str, None], batch_size: int = 1, ): - n_heads = ( - cfg.n_key_value_heads if cfg.n_key_value_heads is not None else cfg.n_heads - ) + n_heads = cfg.n_key_value_heads if cfg.n_key_value_heads is not None else cfg.n_heads return cls( past_keys=torch.empty( (batch_size, 0, n_heads, cfg.d_head), device=device, dtype=cfg.dtype @@ -106,13 +104,9 @@ def unfreeze(self): for entry in self.entries: entry.frozen = False - def append_attention_mask( - self, attention_mask: Int[torch.Tensor, "batch new_tokens"] - ): + def append_attention_mask(self, attention_mask: Int[torch.Tensor, "batch 
new_tokens"]): attention_mask = attention_mask.to(self.previous_attention_mask.device) - updated_attention_mask = torch.cat( - [self.previous_attention_mask, attention_mask], dim=-1 - ) + updated_attention_mask = torch.cat([self.previous_attention_mask, attention_mask], dim=-1) if not self.frozen: self.previous_attention_mask = updated_attention_mask return updated_attention_mask diff --git a/transformer_lens/patching.py b/transformer_lens/patching.py index 189cfda2e..aff08dae0 100644 --- a/transformer_lens/patching.py +++ b/transformer_lens/patching.py @@ -79,11 +79,7 @@ def make_df_from_ranges( """ Takes in a list of column names and max ranges for each column, and returns a dataframe with the cartesian product of the range for each column (ie iterating through all combinations from zero to column_max_range - 1, in order, incrementing the final column first) """ - rows = list( - itertools.product( - *[range(axis_max_range) for axis_max_range in column_max_ranges] - ) - ) + rows = list(itertools.product(*[range(axis_max_range) for axis_max_range in column_max_ranges])) df = pd.DataFrame(rows, columns=column_names) return df @@ -98,9 +94,7 @@ def generic_activation_patch( model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, "batch pos"], clean_cache: ActivationCache, - patching_metric: Callable[ - [Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""] - ], + patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]], patch_setter: Callable[ [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation ], @@ -117,9 +111,7 @@ def generic_activation_patch( model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, "batch pos"], clean_cache: ActivationCache, - patching_metric: Callable[ - [Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""] - ], + patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]], patch_setter: Callable[ 
[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation ], @@ -135,9 +127,7 @@ def generic_activation_patch( model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, "batch pos"], clean_cache: ActivationCache, - patching_metric: Callable[ - [Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""] - ], + patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]], patch_setter: Callable[ [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation ], @@ -187,9 +177,7 @@ def generic_activation_patch( max_axis_range["head"] = max_axis_range["head_index"] # Get the max range for each axis we iterate over - index_axis_max_range = [ - max_axis_range[axis_name] for axis_name in index_axis_names - ] + index_axis_max_range = [max_axis_range[axis_name] for axis_name in index_axis_names] # Get the dataframe where each row is a tuple of indices index_df = make_df_from_ranges(index_axis_max_range, index_axis_names) @@ -206,9 +194,7 @@ def generic_activation_patch( if flattened_output: patched_metric_output = torch.zeros(len(index_df), device=model.cfg.device) else: - patched_metric_output = torch.zeros( - index_axis_max_range, device=model.cfg.device - ) + patched_metric_output = torch.zeros(index_axis_max_range, device=model.cfg.device) # A generic patching hook - for each index, it applies the patch_setter appropriately to patch the activation def patching_hook(corrupted_activation, hook, index, clean_activation): @@ -321,9 +307,7 @@ def layer_head_pos_pattern_patch_setter( """ assert len(index) == 3 layer, head_index, dest_pos = index - corrupted_activation[:, head_index, dest_pos, :] = clean_activation[ - :, head_index, dest_pos, : - ] + corrupted_activation[:, head_index, dest_pos, :] = clean_activation[:, head_index, dest_pos, :] return corrupted_activation @@ -684,9 +668,7 @@ def get_act_patch_attn_head_all_pos_every( """ act_patch_results: list[torch.Tensor] = [] act_patch_results.append( - 
get_act_patch_attn_head_out_all_pos( - model, corrupted_tokens, clean_cache, metric - ) + get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, metric) ) act_patch_results.append( get_act_patch_attn_head_q_all_pos(model, corrupted_tokens, clean_cache, metric) @@ -698,9 +680,7 @@ def get_act_patch_attn_head_all_pos_every( get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, metric) ) act_patch_results.append( - get_act_patch_attn_head_pattern_all_pos( - model, corrupted_tokens, clean_cache, metric - ) + get_act_patch_attn_head_pattern_all_pos(model, corrupted_tokens, clean_cache, metric) ) return torch.stack(act_patch_results, dim=0) @@ -737,9 +717,7 @@ def get_act_patch_attn_head_by_pos_every( pattern_results = get_act_patch_attn_head_pattern_by_pos( model, corrupted_tokens, clean_cache, metric ) - act_patch_results.append( - einops.rearrange(pattern_results, "batch head pos -> batch pos head") - ) + act_patch_results.append(einops.rearrange(pattern_results, "batch head pos -> batch pos head")) return torch.stack(act_patch_results, dim=0) @@ -758,13 +736,7 @@ def get_act_patch_block_every( patched_output (torch.Tensor): The tensor of the patching metric for each patch. 
Has shape [3, n_layers, pos] """ act_patch_results = [] - act_patch_results.append( - get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, metric) - ) - act_patch_results.append( - get_act_patch_attn_out(model, corrupted_tokens, clean_cache, metric) - ) - act_patch_results.append( - get_act_patch_mlp_out(model, corrupted_tokens, clean_cache, metric) - ) + act_patch_results.append(get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, metric)) + act_patch_results.append(get_act_patch_attn_out(model, corrupted_tokens, clean_cache, metric)) + act_patch_results.append(get_act_patch_mlp_out(model, corrupted_tokens, clean_cache, metric)) return torch.stack(act_patch_results, dim=0) diff --git a/transformer_lens/train.py b/transformer_lens/train.py index 946295619..450b58348 100644 --- a/transformer_lens/train.py +++ b/transformer_lens/train.py @@ -102,9 +102,7 @@ def train( optimizer = optim.SGD( model.parameters(), lr=config.lr, - weight_decay=( - config.weight_decay if config.weight_decay is not None else 0.0 - ), + weight_decay=(config.weight_decay if config.weight_decay is not None else 0.0), momentum=config.momentum, ) else: @@ -138,9 +136,7 @@ def train( samples += tokens.shape[0] if config.wandb: - wandb.log( - {"train_loss": loss.item(), "samples": samples, "epoch": epoch} - ) + wandb.log({"train_loss": loss.item(), "samples": samples, "epoch": epoch}) if config.print_every is not None and step % config.print_every == 0: print(f"Epoch {epoch} Samples {samples} Step {step} Loss {loss.item()}") diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py index 27906ee55..40bc0fbd9 100644 --- a/transformer_lens/utilities/devices.py +++ b/transformer_lens/utilities/devices.py @@ -43,9 +43,7 @@ def get_device_for_block_index( def move_to_and_update_config( - model: Union[ - "transformer_lens.HookedTransformer", "transformer_lens.HookedEncoder" - ], + model: Union["transformer_lens.HookedTransformer", 
"transformer_lens.HookedEncoder"], device_or_dtype: Union[torch.device, str, torch.dtype], print_details=True, ): diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index 3a343a7e4..719fa2377 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -31,15 +31,9 @@ USE_DEFAULT_VALUE = None -def select_compatible_kwargs( - kwargs_dict: Dict[str, Any], callable: Callable -) -> Dict[str, Any]: +def select_compatible_kwargs(kwargs_dict: Dict[str, Any], callable: Callable) -> Dict[str, Any]: """Return a dict with the elements kwargs_dict that are parameters of callable""" - return { - k: v - for k, v in kwargs_dict.items() - if k in inspect.getfullargspec(callable).args - } + return {k: v for k, v in kwargs_dict.items() if k in inspect.getfullargspec(callable).args} def download_file_from_hf( @@ -89,9 +83,7 @@ def clear_huggingface_cache(): def print_gpu_mem(step_name=""): - print( - f"{step_name} ~ {np.round(torch.cuda.memory_allocated()/2e30, 2)} GiB allocated on GPU." - ) + print(f"{step_name} ~ {np.round(torch.cuda.memory_allocated()/2e30, 2)} GiB allocated on GPU.") def get_corner(tensor, n=3): @@ -135,9 +127,7 @@ def lm_cross_entropy_loss( # Use torch.gather to find the log probs of the correct tokens # Offsets needed because we're predicting the NEXT token (this means the final logit is meaningless) # None and [..., 0] needed because the tensor used in gather must have the same rank. 
- predicted_log_probs = log_probs[..., :-1, :].gather( - dim=-1, index=tokens[..., 1:, None] - )[..., 0] + predicted_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=tokens[..., 1:, None])[..., 0] if per_token: return -predicted_log_probs else: @@ -168,28 +158,17 @@ def gelu_new( return ( 0.5 * input - * ( - 1.0 - + torch.tanh( - np.sqrt(2.0 / np.pi) * (input + 0.044715 * torch.pow(input, 3.0)) - ) - ) + * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) ) def gelu_fast( input: Float[torch.Tensor, "batch pos d_mlp"] ) -> Float[torch.Tensor, "batch pos d_mlp"]: - return ( - 0.5 - * input - * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) - ) + return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) -def solu( - input: Float[torch.Tensor, "batch pos d_mlp"] -) -> Float[torch.Tensor, "batch pos d_mlp"]: +def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]: """ SoLU activation function as described by https://transformer-circuits.pub/2022/solu/index.html. @@ -219,9 +198,7 @@ def calc_fan_in_and_fan_out(tensor): fan_in = shape[1] fan_out = shape[0] * shape[2] else: - raise ValueError( - f"Fan in and fan out can not be computed for shape {shape} tensors." - ) + raise ValueError(f"Fan in and fan out can not be computed for shape {shape} tensors.") return fan_in, fan_out @@ -329,14 +306,9 @@ def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]: # Divide into 20 chunks of ~ equal length num_chunks = 20 chunk_length = (len(full_text) - 1) // num_chunks + 1 - chunks = [ - full_text[i * chunk_length : (i + 1) * chunk_length] - for i in range(num_chunks) - ] + chunks = [full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)] # Tokenize the chunks in parallel. 
Uses NumPy because HuggingFace map doesn't want tensors returned - tokens = tokenizer(chunks, return_tensors="np", padding=True)[ - "input_ids" - ].flatten() + tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten() # Drop padding tokens tokens = tokens[tokens != tokenizer.pad_token_id] num_tokens = len(tokens) @@ -391,9 +363,7 @@ def sample_logits( final_logits = final_logits / temperature if freq_penalty > 0: - assert ( - tokens is not None - ), "Must provide input_tokens if applying a frequency penalty" + assert tokens is not None, "Must provide input_tokens if applying a frequency penalty" for batch_index in range(final_logits.shape[0]): # torch.bincount returns a tensor of length d_vocab, with the number of occurences of each token in the tokens. final_logits[batch_index] = final_logits[ @@ -412,9 +382,7 @@ def sample_logits( cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) # We round up - we want prob >= top_p not top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ - ..., :-1 - ].clone() + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices_to_remove.scatter( -1, sorted_indices, sorted_indices_to_remove @@ -592,11 +560,7 @@ def get_act_name( get_act_name('scale4ln1')=='blocks.4.ln1.hook_scale' get_act_name('pre5')=='blocks.5.mlp.hook_pre' """ - if ( - ("." in name or name.startswith("hook_")) - and layer is None - and layer_type is None - ): + if ("." 
in name or name.startswith("hook_")) and layer is None and layer_type is None: # If this was called on a full name, just return it return name match = re.match(r"([a-z]+)(\d+)([a-z]?.*)", name) @@ -657,9 +621,7 @@ def get_act_name( return full_act_name -def remove_batch_dim( - tensor: Float[torch.Tensor, "1 ..."] -) -> Float[torch.Tensor, "..."]: +def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]: """ Removes the first dimension of a tensor if it is size 1, otherwise returns the tensor unchanged """ @@ -859,9 +821,7 @@ def is_lower_triangular(x: torch.Tensor) -> bool: return x.equal(x.tril()) -def check_structure( - t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False -) -> None: +def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None: """Validate that the two square tensors have the same structure, i.e., that the directionality of comparisons points in the same directions both row-wise and column-wise. @@ -958,9 +918,7 @@ def get_cumsum_along_dim(tensor, dim, reverse=False): return cumsum -def get_attention_mask( - tokenizer, tokens: torch.Tensor, prepend_bos: bool -) -> torch.Tensor: +def get_attention_mask(tokenizer, tokens: torch.Tensor, prepend_bos: bool) -> torch.Tensor: """ Computes the attention mask for the tokenized input. 
NOTE: Only the leftmost leading pads (when `padding_side == left`) @@ -1072,8 +1030,7 @@ def __init__(self, model, **overrides): "padding_side": { "default_location": "model.tokenizer.padding_side", "valid_values": [USE_DEFAULT_VALUE, "left", "right"], - "skip_overriding": model.tokenizer - is None, # Do not override if tokenizer is None + "skip_overriding": model.tokenizer is None, # Do not override if tokenizer is None "default_value_to_restore": None, # Will be set later }, } @@ -1106,9 +1063,7 @@ def __enter__(self): info["default_value_to_restore"] = deepcopy(default_value) # Override the default value - locally_overriden_value = override_or_use_default_value( - default_value, override - ) + locally_overriden_value = override_or_use_default_value(default_value, override) set_nested_attr(self, default_location, locally_overriden_value) def __exit__(self, exc_type, exc_val, exc_tb): @@ -1196,16 +1151,12 @@ def get_tokens_with_bos_removed(tokenizer, tokens): if tokenizer.bos_token_id == tokenizer.pad_token_id: is_not_pad_token = tokens.ne(tokenizer.pad_token_id) - is_leading_pad = ( - get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0 - ) + is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0 real_bos_positions = is_leading_pad.sum(-1) - 1 else: real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1) - tokens = tokens.scatter( - dim=1, index=real_bos_positions.unsqueeze(-1), value=-100 - ) + tokens = tokens.scatter(dim=1, index=real_bos_positions.unsqueeze(-1), value=-100) return tokens[tokens != -100].view(*bos_removed_shape) From b156ce169e8a0422d677e66ce6aca905f65f2bb6 Mon Sep 17 00:00:00 2001 From: Vasil Georgiev <149842188+VasilGeorgiev39@users.noreply.github.com> Date: Sat, 13 Apr 2024 19:49:52 +0100 Subject: [PATCH 54/73] Refactor hook_points (#505) * Refactor hook_points * restored remaining refactor * ran format * added partial registering again * restored prepend * added type comment again * fixed 
spacing --------- Co-authored-by: Bryce Meyer --- transformer_lens/hook_points.py | 85 ++++++++++++++++----------------- 1 file changed, 40 insertions(+), 45 deletions(-) diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py index fb01e2a50..858c3f48b 100644 --- a/transformer_lens/hook_points.py +++ b/transformer_lens/hook_points.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import partial -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union, cast +from typing import Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast import torch.nn as nn import torch.utils.hooks as hooks @@ -54,54 +54,52 @@ def __init__(self): def add_perma_hook(self, hook, dir="fwd") -> None: self.add_hook(hook, dir=dir, is_permanent=True) - def add_hook(self, hook, dir="fwd", is_permanent=False, level=None, prepend=False) -> None: + def add_hook( + self, + hook: Callable, + dir: Literal["fwd", "bwd"] = "fwd", + is_permanent: bool = False, + level: Optional[int] = None, + prepend: bool = False, + ) -> None: """ Hook format is fn(activation, hook_name) Change it into PyTorch hook format (this includes input and output, which are the same for a HookPoint) If prepend is True, add this hook before all other hooks """ - if dir == "fwd": - def full_hook(module, module_input, module_output): - return hook(module_output, hook=self) + def full_hook(module, module_input, module_output): + if ( + dir == "bwd" + ): # For a backwards hook, module_output is a tuple of (grad,) - I don't know why. 
+ module_output = module_output[0] + return hook(module_output, hook=self) - full_hook.__name__ = ( - hook.__repr__() - ) # annotate the `full_hook` with the string representation of the `hook` function + full_hook.__name__ = ( + hook.__repr__() + ) # annotate the `full_hook` with the string representation of the `hook` function + if dir == "fwd": pt_handle = self.register_forward_hook(full_hook) - handle = LensHandle(pt_handle, is_permanent, level) - - if prepend: - # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this... - self._forward_hooks.move_to_end(handle.hook.id, last=False) # type: ignore # TODO: this type error could signify a bug - self.fwd_hooks.insert(0, handle) - - else: - self.fwd_hooks.append(handle) - + _internal_hooks = self._forward_hooks + visible_hooks = self.fwd_hooks elif dir == "bwd": - # For a backwards hook, module_output is a tuple of (grad,) - I don't know why. - - def full_hook(module, module_input, module_output): - return hook(module_output[0], hook=self) + pt_handle = self.register_full_backward_hook(full_hook) + _internal_hooks = self._backward_hooks + visible_hooks = self.bwd_hooks + else: + raise ValueError(f"Invalid direction {dir}") - full_hook.__name__ = ( - hook.__repr__() - ) # annotate the `full_hook` with the string representation of the `hook` function + handle = LensHandle(pt_handle, is_permanent, level) - pt_handle = self.register_full_backward_hook(full_hook) - handle = LensHandle(pt_handle, is_permanent, level) + if prepend: + # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this... + _internal_hooks.move_to_end(handle.hook.id, last=False) # type: ignore # TODO: this type error could signify a bug + visible_hooks.insert(0, handle) - if prepend: - # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this... 
- self._backward_hooks.move_to_end(handle.hook.id, last=False) # type: ignore # TODO: this type error could signify a bug - self.bwd_hooks.insert(0, handle) - else: - self.bwd_hooks.append(handle) else: - raise ValueError(f"Invalid direction {dir}") + visible_hooks.append(handle) def remove_hooks(self, dir="fwd", including_permanent=False, level=None) -> None: def _remove_hooks(handles: List[LensHandle]) -> List[LensHandle]: @@ -396,23 +394,20 @@ def add_caching_hooks( self.is_caching = True - def save_hook(tensor, hook): - if remove_batch_dim: - cache[hook.name] = tensor.detach().to(device)[0] - else: - cache[hook.name] = tensor.detach().to(device) - - def save_hook_back(tensor, hook): + def save_hook(tensor, hook, is_backward): + hook_name = hook.name + if is_backward: + hook_name += "_grad" if remove_batch_dim: - cache[hook.name + "_grad"] = tensor.detach().to(device)[0] + cache[hook_name] = tensor.detach().to(device)[0] else: - cache[hook.name + "_grad"] = tensor.detach().to(device) + cache[hook_name] = tensor.detach().to(device) for name, hp in self.hook_dict.items(): if names_filter(name): - hp.add_hook(save_hook, "fwd") + hp.add_hook(partial(save_hook, is_backward=False), "fwd") if incl_bwd: - hp.add_hook(save_hook_back, "bwd") + hp.add_hook(partial(save_hook, is_backward=True), "bwd") return cache def run_with_cache( From 1553e81ee27c2b4b2fecff25f0dd36c7564e5ded Mon Sep 17 00:00:00 2001 From: Wes Gurnee <30759075+wesg52@users.noreply.github.com> Date: Sat, 13 Apr 2024 16:54:09 -0400 Subject: [PATCH 55/73] Fix split_qkv_input for grouped query attention (#520) * qkv initial fix * add test and update BertBlock * formatting changes * fix flaky gqa test * move helper function to utils * ran reformat --------- Co-authored-by: Bryce Meyer --- tests/unit/test_grouped_query_attention.py | 14 ++++ tests/unit/test_split_qkv.py | 73 ++++++++++++++++++++ transformer_lens/components.py | 77 +++++++++++----------- transformer_lens/utils.py | 17 +++++ 4 files changed, 143 
insertions(+), 38 deletions(-) create mode 100644 tests/unit/test_split_qkv.py diff --git a/tests/unit/test_grouped_query_attention.py b/tests/unit/test_grouped_query_attention.py index 633bb9180..e5e603454 100644 --- a/tests/unit/test_grouped_query_attention.py +++ b/tests/unit/test_grouped_query_attention.py @@ -1,3 +1,4 @@ +import einops import torch from transformer_lens.components import Attention, GroupedQueryAttention @@ -78,3 +79,16 @@ def test_grouped_query_attention_output_is_correct(): grouped_query_attn_output = grouped_query_attention(query_input, key_input, value_input) assert torch.equal(regular_attn_output, grouped_query_attn_output) + + # Test GQA behaves correctly when use_split_qkv_input is True + grouped_query_attention.cfg.use_split_qkv_input = True + + split_query_input = einops.repeat(query_input, "b n d -> b n h d", h=n_heads).clone() + split_key_input = einops.repeat(key_input, "b n d -> b n h d", h=n_key_value_heads).clone() + split_value_input = einops.repeat(value_input, "b n d -> b n h d", h=n_key_value_heads).clone() + + split_grouped_query_attn_output = grouped_query_attention( + split_query_input, split_key_input, split_value_input + ) + + assert torch.allclose(regular_attn_output, split_grouped_query_attn_output, rtol=1e-6) diff --git a/tests/unit/test_split_qkv.py b/tests/unit/test_split_qkv.py new file mode 100644 index 000000000..e8586305b --- /dev/null +++ b/tests/unit/test_split_qkv.py @@ -0,0 +1,73 @@ +import torch + +from transformer_lens import HookedTransformer +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + + +def test_split_qkv_normal_attn_correct(): + """Verifies that the split_qkv_input flag does not change the output for models with normal attention.""" + d_model = 128 + d_head = 8 + n_heads = 16 + n_ctx = 128 + n_layers = 1 + d_vocab = 10 + + cfg = HookedTransformerConfig( + d_model=d_model, + d_head=d_head, + n_heads=n_heads, + n_ctx=n_ctx, + n_layers=n_layers, + attn_only=True, + 
d_vocab=d_vocab, + ) + + model = HookedTransformer(cfg) + assert model.cfg.use_split_qkv_input is False + + x = torch.arange(1, 9).unsqueeze(0) + normal_output = model(x) + + model.set_use_split_qkv_input(True) + assert model.cfg.use_split_qkv_input is True + + split_output = model(x) + + assert torch.allclose(normal_output, split_output, atol=1e-6) + + +def test_split_qkv_grouped_query_attn_correct(): + """Verifies that the split_qkv_input flag does not change the output for models with grouped query attention.""" + + d_model = 128 + d_head = 8 + n_heads = 16 + n_ctx = 128 + n_key_value_heads = 2 + n_layers = 1 + d_vocab = 10 + + cfg = HookedTransformerConfig( + d_model=d_model, + d_head=d_head, + n_heads=n_heads, + n_ctx=n_ctx, + n_key_value_heads=n_key_value_heads, + n_layers=n_layers, + attn_only=True, + d_vocab=d_vocab, + ) + + model = HookedTransformer(cfg) + assert model.cfg.use_split_qkv_input is False + + x = torch.arange(1, 9).unsqueeze(0) + normal_output = model(x) + + model.set_use_split_qkv_input(True) + assert model.cfg.use_split_qkv_input is True + + split_output = model(x) + + assert torch.allclose(normal_output, split_output, atol=1e-6) diff --git a/transformer_lens/components.py b/transformer_lens/components.py index a3cd7f5f4..b61d5bc0b 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -22,7 +22,13 @@ from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry -from transformer_lens.utils import gelu_fast, gelu_new, get_offset_position_ids, solu +from transformer_lens.utils import ( + gelu_fast, + gelu_new, + get_offset_position_ids, + repeat_along_head_dimension, + solu, +) # Embed & Unembed @@ -496,10 +502,12 @@ def forward( key_input: Union[ Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos head_index d_model"], + Float[torch.Tensor, 
"batch pos kv_head_index d_model"], ], value_input: Union[ Float[torch.Tensor, "batch pos d_model"], Float[torch.Tensor, "batch pos head_index d_model"], + Float[torch.Tensor, "batch pos kv_head_index d_model"], ], past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None, @@ -1056,13 +1064,14 @@ def calculate_qkv_matrices( A tuple containing the Q, K, and V matrices with the specified shapes. """ if self.cfg.use_split_qkv_input or self.cfg.use_attn_in: - qkv_einops_string = "batch pos kv_head_index d_model" + kv_einops_string = "batch pos kv_head_index d_model" + q_einops_string = "batch pos head_index d_model" else: - qkv_einops_string = "batch pos d_model" + kv_einops_string = q_einops_string = "batch pos d_model" q = self.hook_q( einsum( - f"{qkv_einops_string}, head_index d_model d_head \ + f"{q_einops_string}, head_index d_model d_head \ -> batch pos head_index d_head", query_input, self.W_Q, @@ -1071,7 +1080,7 @@ def calculate_qkv_matrices( ) # [batch, pos, head_index, d_head] k = self.hook_k( einsum( - f"{qkv_einops_string}, kv_head_index d_model d_head \ + f"{kv_einops_string}, kv_head_index d_model d_head \ -> batch pos kv_head_index d_head", key_input, self._W_K, @@ -1080,7 +1089,7 @@ def calculate_qkv_matrices( ) # [batch, pos, head_index, d_head] v = self.hook_v( einsum( - f"{qkv_einops_string}, kv_head_index d_model d_head \ + f"{kv_einops_string}, kv_head_index d_model d_head \ -> batch pos kv_head_index d_head", value_input, self._W_V, @@ -1408,36 +1417,35 @@ def forward( """ resid_pre = self.hook_resid_pre(resid_pre) # [batch, pos, d_model] - def add_head_dimension( - tensor: Float[torch.Tensor, "batch pos d_model"], - clone_tensor=True, - # `einops.repeat` uses a view in torch, so we generally clone the tensor to avoid using shared storage for each head entry - ): - repeated_tensor = einops.repeat( - tensor, - "batch pos d_model -> batch pos n_heads 
d_model", - n_heads=self.cfg.n_heads, - ) - if clone_tensor: - return repeated_tensor.clone() - else: - return repeated_tensor - if self.cfg.use_attn_in or self.cfg.use_split_qkv_input: # We're adding a head dimension - attn_in = add_head_dimension(resid_pre, clone_tensor=False) if shortformer_pos_embed is not None: - shortformer_pos_embed = add_head_dimension(shortformer_pos_embed) + shortformer_pos_embed = repeat_along_head_dimension( + shortformer_pos_embed, n_heads=self.cfg.n_heads + ) else: attn_in = resid_pre if self.cfg.use_attn_in: - attn_in = self.hook_attn_in(attn_in.clone()) + attn_in = self.hook_attn_in( + repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads) + ) if self.cfg.use_split_qkv_input: - query_input = self.hook_q_input(attn_in.clone()) - key_input = self.hook_k_input(attn_in.clone()) - value_input = self.hook_v_input(attn_in.clone()) + n_kv_heads = ( + self.cfg.n_key_value_heads + if self.cfg.n_key_value_heads is not None + else self.cfg.n_heads + ) + query_input = self.hook_q_input( + repeat_along_head_dimension(resid_pre, n_heads=self.cfg.n_heads) + ) + key_input = self.hook_k_input( + repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads) + ) + value_input = self.hook_v_input( + repeat_along_head_dimension(resid_pre, n_heads=n_kv_heads) + ) else: query_input = attn_in key_input = attn_in @@ -1518,17 +1526,10 @@ def forward( value_input = resid_pre if self.cfg.use_split_qkv_input: - - def add_head_dimension(tensor): - return einops.repeat( - tensor, - "batch pos d_model -> batch pos n_heads d_model", - n_heads=self.cfg.n_heads, - ).clone() - - query_input = self.hook_q_input(add_head_dimension(query_input)) - key_input = self.hook_k_input(add_head_dimension(key_input)) - value_input = self.hook_v_input(add_head_dimension(value_input)) + n_heads = self.cfg.n_heads + query_input = self.hook_q_input(repeat_along_head_dimension(query_input, n_heads)) + key_input = self.hook_k_input(repeat_along_head_dimension(key_input, n_heads)) 
+ value_input = self.hook_v_input(repeat_along_head_dimension(value_input, n_heads)) attn_out = self.hook_attn_out( self.attn( diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index 719fa2377..1b33f99fa 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -957,6 +957,23 @@ def get_attention_mask(tokenizer, tokens: torch.Tensor, prepend_bos: bool) -> to return attention_mask +def repeat_along_head_dimension( + tensor: Float[torch.Tensor, "batch pos d_model"], + n_heads: int, + clone_tensor=True, + # `einops.repeat` uses a view in torch, so we generally clone the tensor to avoid using shared storage for each head entry +): + repeated_tensor = einops.repeat( + tensor, + "batch pos d_model -> batch pos n_heads d_model", + n_heads=n_heads, + ) + if clone_tensor: + return repeated_tensor.clone() + else: + return repeated_tensor + + def get_nested_attr(obj, attr_str): """ Retrieves a nested attribute from an object based on a dot-separated string. From 3aaac2ed1ad04b8e408476ed5c7390029816a9d9 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Tue, 16 Apr 2024 01:03:00 +0200 Subject: [PATCH 56/73] locked attribution patching to 1.1.1 (#541) --- demos/Attribution_Patching_Demo.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demos/Attribution_Patching_Demo.ipynb b/demos/Attribution_Patching_Demo.ipynb index 284b75412..cef67eb8b 100644 --- a/demos/Attribution_Patching_Demo.ipynb +++ b/demos/Attribution_Patching_Demo.ipynb @@ -1 +1 @@ -{"cells":[{"cell_type":"markdown","metadata":{},"source":["\n"," \"Open\n",""]},{"cell_type":"markdown","metadata":{},"source":[" # Attribution Patching Demo\n"," **Read [the accompanying blog post here](https://neelnanda.io/attribution-patching) for more context**\n"," This is an interim research report, giving a whirlwind tour of some unpublished work I did at Anthropic (credit to the then team - Chris Olah, Catherine Olsson, Nelson Elhage and Tristan Hume for help, support, and 
mentorship!)\n","\n"," The goal of this work is run activation patching at an industrial scale, by using gradient based attribution to approximate the technique - allow an arbitrary number of patches to be made on two forwards and a single backward pass\n","\n"," I have had less time than hoped to flesh out this investigation, but am writing up a rough investigation and comparison to standard activation patching on a few tasks to give a sense of the potential of this approach, and where it works vs falls down."]},{"cell_type":"markdown","metadata":{},"source":[" To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n","\n"," **Tips for reading this Colab:**\n"," * You can run all this code for yourself!\n"," * The graphs are interactive!\n"," * Use the table of contents pane in the sidebar to navigate\n"," * Collapse irrelevant sections with the dropdown arrows\n"," * Search the page using the search in the sidebar, not CTRL+F"]},{"cell_type":"markdown","metadata":{},"source":[" ## Setup (Ignore)"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Running as a Jupyter notebook - intended for development only!\n"]}],"source":["# Janky code to do different setup when run in a Colab notebook vs VSCode\n","DEBUG_MODE = False\n","try:\n"," import google.colab\n"," IN_COLAB = True\n"," print(\"Running as a Colab notebook\")\n"," %pip install transformer_lens\n"," %pip install torchtyping\n"," # Install my janky personal plotting utils\n"," %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n"," # Install another version of node that makes PySvelte work way faster\n"," !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n"," %pip install git+https://github.com/neelnanda-io/PySvelte.git\n"," # Needed for PySvelte to work, v3 came out and broke things...\n"," %pip install 
typeguard==2.13.3\n","except:\n"," IN_COLAB = False\n"," print(\"Running as a Jupyter notebook - intended for development only!\")\n"," from IPython import get_ipython\n","\n"," ipython = get_ipython()\n"," # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n"," ipython.magic(\"load_ext autoreload\")\n"," ipython.magic(\"autoreload 2\")"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":["# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n","import plotly.io as pio\n","\n","if IN_COLAB or not DEBUG_MODE:\n"," # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.\n"," pio.renderers.default = \"colab\"\n","else:\n"," pio.renderers.default = \"notebook_connected\""]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":["# Import stuff\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import numpy as np\n","import einops\n","from fancy_einsum import einsum\n","import tqdm.notebook as tqdm\n","import random\n","from pathlib import Path\n","import plotly.express as px\n","from torch.utils.data import DataLoader\n","\n","from torchtyping import TensorType as TT\n","from typing import List, Union, Optional, Callable\n","from functools import partial\n","import copy\n","import itertools\n","import json\n","\n","from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n","import dataclasses\n","import datasets\n","from IPython.display import HTML, Markdown"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":["import pysvelte\n","\n","import transformer_lens\n","import transformer_lens.utils as utils\n","from transformer_lens.hook_points import (\n"," HookedRootModule,\n"," HookPoint,\n",") # Hooking utilities\n","from 
transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache"]},{"cell_type":"markdown","metadata":{},"source":[" Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":["from neel_plotly import line, imshow, scatter"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["import transformer_lens.patching as patching"]},{"cell_type":"markdown","metadata":{},"source":[" ## IOI Patching Setup\n"," This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important."]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-small into HookedTransformer\n"]}],"source":["model = HookedTransformer.from_pretrained(\"gpt2-small\")\n","model.set_use_attn_result(True)"]},{"cell_type":"code","execution_count":9,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to\n","Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n","Answer token indices tensor([[ 5335, 1757],\n"," [ 1757, 5335],\n"," [ 4186, 3700],\n"," [ 3700, 4186],\n"," [ 6035, 15686],\n"," [15686, 6035],\n"," [ 5780, 14235],\n"," [14235, 5780]], device='cuda:0')\n"]}],"source":["prompts = ['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to 
the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']\n","answers = [(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n","\n","clean_tokens = model.to_tokens(prompts)\n","# Swap each adjacent pair, with a hacky list comprehension\n","corrupted_tokens = clean_tokens[\n"," [(i+1 if i%2==0 else i-1) for i in range(len(clean_tokens)) ]\n"," ]\n","print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n","print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n","\n","answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=model.cfg.device)\n","print(\"Answer token indices\", answer_token_indices)"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 3.5519\n","Corrupted logit diff: -3.5519\n"]}],"source":["def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n"," if len(logits.shape)==3:\n"," # Get final logits only\n"," logits = logits[:, -1, :]\n"," correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n"," incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n"," return (correct_logits - incorrect_logits).mean()\n","\n","clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n","corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n","\n","clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n","print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n","\n","corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n","print(f\"Corrupted logit diff: 
{corrupted_logit_diff:.4f}\")"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Baseline is 1: 1.0000\n","Corrupted Baseline is 0: 0.0000\n"]}],"source":["CLEAN_BASELINE = clean_logit_diff\n","CORRUPTED_BASELINE = corrupted_logit_diff\n","def ioi_metric(logits, answer_token_indices=answer_token_indices):\n"," return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE - CORRUPTED_BASELINE)\n","\n","print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n","print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Patching\n"," In the following cells, we define attribution patching and use it in various ways on the model."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["Metric = Callable[[TT[\"batch_and_pos_dims\", \"d_model\"]], float]"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Value: 1.0\n","Clean Activations Cached: 220\n","Clean Gradients Cached: 220\n","Corrupted Value: 0.0\n","Corrupted Activations Cached: 220\n","Corrupted Gradients Cached: 220\n"]}],"source":["filter_not_qkv_input = lambda name: \"_input\" not in name\n","def get_cache_fwd_and_bwd(model, tokens, metric):\n"," model.reset_hooks()\n"," cache = {}\n"," def forward_cache_hook(act, hook):\n"," cache[hook.name] = act.detach()\n"," model.add_hook(filter_not_qkv_input, forward_cache_hook, \"fwd\")\n","\n"," grad_cache = {}\n"," def backward_cache_hook(act, hook):\n"," grad_cache[hook.name] = act.detach()\n"," model.add_hook(filter_not_qkv_input, backward_cache_hook, \"bwd\")\n","\n"," value = metric(model(tokens))\n"," value.backward()\n"," model.reset_hooks()\n"," return value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)\n","\n","clean_value, clean_cache, 
clean_grad_cache = get_cache_fwd_and_bwd(model, clean_tokens, ioi_metric)\n","print(\"Clean Value:\", clean_value)\n","print(\"Clean Activations Cached:\", len(clean_cache))\n","print(\"Clean Gradients Cached:\", len(clean_grad_cache))\n","corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(model, corrupted_tokens, ioi_metric)\n","print(\"Corrupted Value:\", corrupted_value)\n","print(\"Corrupted Activations Cached:\", len(corrupted_cache))\n","print(\"Corrupted Gradients Cached:\", len(corrupted_grad_cache))"]},{"cell_type":"markdown","metadata":{},"source":[" ### Attention Attribution\n"," The easiest thing to start with is to not even engage with the corrupted tokens/patching, but to look at the attribution of the attention patterns - that is, the linear approximation to what happens if you set each element of the attention pattern to zero. This, as it turns out, is a good proxy to what is going on with each head!\n"," Note that this is *not* the same as what we will later do with patching. In particular, this does not set up a careful counterfactual! It's a good tool for what's generally going on in this problem, but does not control for eg stuff that systematically boosts John > Mary in general, stuff that says \"I should activate the IOI circuit\", etc. Though using logit diff as our metric *does*\n"," Each element of the batch is independent and the metric is an average logit diff, so we can analyse each batch element independently here. 
We'll look at the first one, and then at the average across the whole batch (note - 4 prompts have indirect object before subject, 4 prompts have it the other way round, making the average pattern harder to interpret - I plot it over the first sequence of tokens as a mildly misleading reference).\n"," We can compare it to the interpretability in the wild diagram, and basically instantly recover most of the circuit!"]},{"cell_type":"code","execution_count":14,"metadata":{},"outputs":[],"source":["def create_attention_attr(clean_cache, clean_grad_cache) -> TT[\"batch\", \"layer\", \"head_index\", \"dest\", \"src\"]:\n"," attention_stack = torch.stack([clean_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," attention_grad_stack = torch.stack([clean_grad_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," attention_attr = attention_grad_stack * attention_stack\n"," attention_attr = einops.rearrange(attention_attr, \"layer batch head_index dest src -> batch layer head_index dest src\")\n"," return attention_attr\n","\n","attention_attr = create_attention_attr(clean_cache, clean_grad_cache)"]},{"cell_type":"code","execution_count":15,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4']\n","['L0H0+', 'L0H0-', 'L0H1+', 'L0H1-', 'L0H2+']\n","['L0H0Q', 'L0H0K', 'L0H0V', 'L0H1Q', 'L0H1K']\n"]}],"source":["HEAD_NAMES = [f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]\n","HEAD_NAMES_SIGNED = [f\"{name}{sign}\" for name in HEAD_NAMES for sign in [\"+\", \"-\"]]\n","HEAD_NAMES_QKV = [f\"{name}{act_name}\" for name in HEAD_NAMES for act_name in [\"Q\", \"K\", \"V\"]]\n","print(HEAD_NAMES[:5])\n","print(HEAD_NAMES_SIGNED[:5])\n","print(HEAD_NAMES_QKV[:5])"]},{"cell_type":"markdown","metadata":{},"source":[" An extremely janky way to plot the attention attribution patterns. 
We scale them to be in [-1, 1], split each head into a positive and negative part (so all of it is in [0, 1]), and then plot the top 20 head-halves (a head can appear twice!) by the max value of the attribution pattern."]},{"cell_type":"code","execution_count":16,"metadata":{},"outputs":[{"data":{"text/markdown":["### Attention Attribution for first sequence"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["### Summed Attention Attribution for all sequences"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\n"]}],"source":["def plot_attention_attr(attention_attr, tokens, top_k=20, index=0, title=\"\"):\n"," if len(tokens.shape)==2:\n"," tokens = tokens[index]\n"," if len(attention_attr.shape)==5:\n"," attention_attr = attention_attr[index]\n"," attention_attr_pos = attention_attr.clamp(min=-1e-5)\n"," attention_attr_neg = - attention_attr.clamp(max=1e-5)\n"," attention_attr_signed = torch.stack([attention_attr_pos, attention_attr_neg], dim=0)\n"," attention_attr_signed = einops.rearrange(attention_attr_signed, \"sign layer head_index dest src -> (layer head_index sign) dest src\")\n"," attention_attr_signed = attention_attr_signed / attention_attr_signed.max()\n"," attention_attr_indices = attention_attr_signed.max(-1).values.max(-1).values.argsort(descending=True)\n"," # print(attention_attr_indices.shape)\n"," # print(attention_attr_indices)\n"," attention_attr_signed = attention_attr_signed[attention_attr_indices, :, :]\n"," head_labels = [HEAD_NAMES_SIGNED[i.item()] for i in attention_attr_indices]\n","\n"," if title: display(Markdown(\"### \"+title))\n"," display(pysvelte.AttentionMulti(tokens=model.to_str_tokens(tokens), attention=attention_attr_signed.permute(1, 2, 0)[:, :, :top_k], head_labels=head_labels[:top_k]))\n","\n","plot_attention_attr(attention_attr, clean_tokens, index=0, title=\"Attention Attribution for first sequence\")\n","\n","plot_attention_attr(attention_attr.sum(0), clean_tokens[0], title=\"Summed Attention Attribution for all sequences\")\n","print(\"Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Attribution Patching\n"," In the following sections, I will implement various kinds of attribution patching, and 
then compare them to the activation patching patterns (activation patching code copied from [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo))\n"," ### Residual Stream Patching\n","
Note: We add up across both d_model and batch (Explanation).\n"," We add up along d_model because we're taking the dot product - the derivative *is* the linear map that locally linearly approximates the metric, and so we take the dot product of our change vector with the derivative vector. Equivalent, we look at the effect of changing each coordinate independently, and then combine them by adding it up - it's linear, so this totally works.\n"," We add up across batch because we're taking the average of the metric, so each individual batch element provides `1/batch_size` of the overall effect. Because each batch element is independent of the others and no information moves between activations for different inputs, the batched version is equivalent to doing attribution patching separately for each input, and then averaging - in this second version the metric per input is *not* divided by batch_size because we don't average.
"]},{"cell_type":"code","execution_count":17,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_residual(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," clean_residual, residual_labels = clean_cache.accumulated_resid(-1, incl_mid=True, return_labels=True)\n"," corrupted_residual = corrupted_cache.accumulated_resid(-1, incl_mid=True, return_labels=False)\n"," corrupted_grad_residual = corrupted_grad_cache.accumulated_resid(-1, incl_mid=True, return_labels=False)\n"," residual_attr = einops.reduce(\n"," corrupted_grad_residual * (clean_residual - corrupted_residual),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return residual_attr, residual_labels\n","\n","residual_attr, residual_labels = attr_patch_residual(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(residual_attr, y=residual_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Residual Attribution Patching\")\n","\n","# ### Layer Output Patching"]},{"cell_type":"code","execution_count":18,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_layer_out(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," clean_layer_out, labels = clean_cache.decompose_resid(-1, return_labels=True)\n"," corrupted_layer_out = corrupted_cache.decompose_resid(-1, return_labels=False)\n"," corrupted_grad_layer_out = corrupted_grad_cache.decompose_resid(-1, return_labels=False)\n"," layer_out_attr = einops.reduce(\n"," corrupted_grad_layer_out * (clean_layer_out - corrupted_layer_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return layer_out_attr, labels\n","\n","layer_out_attr, layer_out_labels = attr_patch_layer_out(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(layer_out_attr, y=layer_out_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Layer Output Attribution Patching\")"]},{"cell_type":"code","execution_count":19,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_head_out(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_out = clean_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_head_out = corrupted_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_grad_head_out = corrupted_grad_cache.stack_head_results(-1, return_labels=False)\n"," head_out_attr = einops.reduce(\n"," corrupted_grad_head_out * (clean_head_out - corrupted_head_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return head_out_attr, labels\n","\n","head_out_attr, head_out_labels = attr_patch_head_out(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(head_out_attr, y=head_out_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Head Output Attribution Patching\")\n","sum_head_out_attr = einops.reduce(head_out_attr, \"(layer head) pos -> layer head\", \"sum\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n","imshow(sum_head_out_attr, yaxis=\"Layer\", xaxis=\"Head Index\", title=\"Head Output Attribution Patching Sum Over Pos\")"]},{"cell_type":"markdown","metadata":{},"source":[" ### Head Activation Patching\n"," Intuitively, a head has three inputs, keys, queries and values. We can patch each of these individually to get a sense for where the important part of each head's input comes from!\n"," As a sanity check, we also do this for the mixed value. 
The result is a linear map of this (`z @ W_O == result`), so this is the same as patching the output of the head.\n"," We plot both the patch for each head over each position, and summed over position (it tends to be pretty sparse, so the latter is the same)"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[{"data":{"text/markdown":["#### Key Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Query Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Mixed Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","def stack_head_vector_from_cache(\n"," cache, \n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"]\n"," ) -> TT[\"layer_and_head_index\", \"batch\", \"pos\", \"d_head\"]:\n"," \"\"\"Stacks the head vectors from the cache from a specific activation (key, query, value or mixed_value (z)) into a single tensor.\"\"\"\n"," stacked_head_vectors = torch.stack([cache[activation_name, l] for l in range(model.cfg.n_layers)], dim=0)\n"," stacked_head_vectors = einops.rearrange(\n"," stacked_head_vectors,\n"," \"layer batch pos head_index d_head -> (layer head_index) batch pos d_head\"\n"," )\n"," return stacked_head_vectors\n","\n","def attr_patch_head_vector(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"],\n"," ) -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_vector = stack_head_vector_from_cache(clean_cache, activation_name)\n"," corrupted_head_vector = stack_head_vector_from_cache(corrupted_cache, activation_name)\n"," corrupted_grad_head_vector = stack_head_vector_from_cache(corrupted_grad_cache, activation_name)\n"," head_vector_attr = einops.reduce(\n"," corrupted_grad_head_vector * (clean_head_vector - corrupted_head_vector),\n"," \"component batch pos d_head -> component pos\",\n"," \"sum\"\n"," )\n"," return head_vector_attr, labels\n","\n","head_vector_attr_dict = {}\n","for activation_name, activation_name_full in [(\"k\", \"Key\"), (\"q\", \"Query\"), (\"v\", \"Value\"), (\"z\", \"Mixed Value\")]:\n"," display(Markdown(f\"#### {activation_name_full} Head Vector Attribution Patching\"))\n"," head_vector_attr_dict[activation_name], head_vector_labels = attr_patch_head_vector(clean_cache, corrupted_cache, corrupted_grad_cache, activation_name)\n"," 
imshow(head_vector_attr_dict[activation_name], y=head_vector_labels, yaxis=\"Component\", xaxis=\"Position\", title=f\"{activation_name_full} Attribution Patching\")\n"," sum_head_vector_attr = einops.reduce(head_vector_attr_dict[activation_name], \"(layer head) pos -> layer head\", \"sum\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n"," imshow(sum_head_vector_attr, yaxis=\"Layer\", xaxis=\"Head Index\", title=f\"{activation_name_full} Attribution Patching Sum Over Pos\")"]},{"cell_type":"code","execution_count":21,"metadata":{},"outputs":[{"data":{"text/markdown":["### Head Pattern Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","def stack_head_pattern_from_cache(\n"," cache, \n"," ) -> TT[\"layer_and_head_index\", \"batch\", \"dest_pos\", \"src_pos\"]:\n"," \"\"\"Stacks the head patterns from the cache into a single tensor.\"\"\"\n"," stacked_head_pattern = torch.stack([cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," stacked_head_pattern = einops.rearrange(\n"," stacked_head_pattern,\n"," \"layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos\"\n"," )\n"," return stacked_head_pattern\n","\n","def attr_patch_head_pattern(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"dest_pos\", \"src_pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_pattern = stack_head_pattern_from_cache(clean_cache)\n"," corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)\n"," corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)\n"," head_pattern_attr = einops.reduce(\n"," corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern),\n"," \"component batch dest_pos src_pos -> component dest_pos src_pos\",\n"," \"sum\"\n"," )\n"," return head_pattern_attr, labels\n","\n","head_pattern_attr, labels = attr_patch_head_pattern(clean_cache, corrupted_cache, corrupted_grad_cache)\n","\n","plot_attention_attr(einops.rearrange(head_pattern_attr, \"(layer head) dest src -> layer head dest src\", layer=model.cfg.n_layers, head=model.cfg.n_heads), clean_tokens, index=0, title=\"Head Pattern Attribution Patching\")"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_head_vector_grad_input_from_grad_cache(\n"," grad_cache: ActivationCache, \n"," activation_name: Literal[\"q\", \"k\", \"v\"],\n"," layer: int\n"," ) -> TT[\"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," vector_grad = grad_cache[activation_name, layer]\n"," ln_scales = grad_cache[\"scale\", layer, \"ln1\"]\n"," attn_layer_object = model.blocks[layer].attn\n"," if activation_name == \"q\":\n"," W = attn_layer_object.W_Q\n"," elif activation_name == \"k\":\n"," W = attn_layer_object.W_K\n"," elif activation_name == \"v\":\n"," W = attn_layer_object.W_V\n"," else:\n"," raise ValueError(\"Invalid activation name\")\n","\n"," return einsum(\"batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model\", vector_grad, ln_scales.squeeze(-1), W)\n","\n","def get_stacked_head_vector_grad_input(grad_cache, activation_name: Literal[\"q\", \"k\", \"v\"]) -> TT[\"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack([get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, l) for l in range(model.cfg.n_layers)], dim=0)\n","\n","def get_full_vector_grad_input(grad_cache) -> TT[\"qkv\", \"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack([get_stacked_head_vector_grad_input(grad_cache, activation_name) for activation_name in ['q', 'k', 'v']], dim=0)\n","\n","def attr_patch_head_path(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache\n"," ) -> TT[\"qkv\", \"dest_component\", \"src_component\", \"pos\"]:\n"," \"\"\"\n"," Computes the attribution patch along the path between each pair of heads.\n","\n"," Sets this to zero for the path from any late head to any early head\n","\n"," \"\"\"\n"," start_labels = HEAD_NAMES\n"," end_labels = HEAD_NAMES_QKV\n"," full_vector_grad_input = 
get_full_vector_grad_input(corrupted_grad_cache)\n"," clean_head_result_stack = clean_cache.stack_head_results(-1)\n"," corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)\n"," diff_head_result = einops.rearrange(\n"," clean_head_result_stack - corrupted_head_result_stack,\n"," \"(layer head_index) batch pos d_model -> layer batch pos head_index d_model\",\n"," layer = model.cfg.n_layers,\n"," head_index = model.cfg.n_heads,\n"," )\n"," path_attr = einsum(\n"," \"qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos\", \n"," full_vector_grad_input, \n"," diff_head_result)\n"," correct_layer_order_mask = (\n"," torch.arange(model.cfg.n_layers)[None, :, None, None, None, None] > \n"," torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]).to(path_attr.device)\n"," zero = torch.zeros(1, device=path_attr.device)\n"," path_attr = torch.where(correct_layer_order_mask, path_attr, zero)\n","\n"," path_attr = einops.rearrange(\n"," path_attr,\n"," \"qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos\",\n"," )\n"," return path_attr, end_labels, start_labels\n","\n","head_path_attr, end_labels, start_labels = attr_patch_head_path(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(head_path_attr.sum(-1), y=end_labels, yaxis=\"Path End (Head Input)\", x=start_labels, xaxis=\"Path Start (Head Output)\", title=\"Head Path Attribution Patching\")"]},{"cell_type":"markdown","metadata":{},"source":[" This is hard to parse. Here's an experiment with filtering for the most important heads and showing their paths."]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)\n","line(head_out_values)\n","top_head_indices = head_out_indices[:22].sort().values\n","top_end_indices = []\n","top_end_labels = []\n","top_start_indices = []\n","top_start_labels = []\n","for i in top_head_indices:\n"," i = i.item()\n"," top_start_indices.append(i)\n"," top_start_labels.append(start_labels[i])\n"," for j in range(3):\n"," top_end_indices.append(3*i+j)\n"," top_end_labels.append(end_labels[3*i+j])\n","\n","imshow(head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1), y=top_end_labels, yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=\"Head Path Attribution Patching (Filtered for Top Heads)\")"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["for j, composition_type in enumerate([\"Query\", \"Key\", \"Value\"]):\n"," imshow(head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1), y=top_end_labels[j::3], yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=f\"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)\")"]},{"cell_type":"code","execution_count":25,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["top_head_path_attr = einops.rearrange(head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1), \"(head_end qkv) head_start -> qkv head_end head_start\", qkv=3)\n","imshow(top_head_path_attr, y=[i[:-1] for i in top_end_labels[::3]], yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=f\"Head Path Attribution Patching (Filtered for Top Heads)\", facet_col=0, facet_labels=[\"Query\", \"Key\", \"Value\"])"]},{"cell_type":"markdown","metadata":{},"source":[" Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - Q input means )"]},{"cell_type":"code","execution_count":26,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["interesting_heads = [5 * model.cfg.n_heads + 5, 8 * model.cfg.n_heads + 6, 9 * model.cfg.n_heads + 9]\n","interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]\n","for head_index, label in zip(interesting_heads, interesting_head_labels):\n"," in_paths = head_path_attr[3*head_index:3*head_index+3].sum(-1)\n"," out_paths = head_path_attr[:, head_index].sum(-1)\n"," out_paths = einops.rearrange(out_paths, \"(layer_head qkv) -> qkv layer_head\", qkv=3)\n"," all_paths = torch.cat([in_paths, out_paths], dim=0)\n"," all_paths = einops.rearrange(all_paths, \"path_type (layer head) -> path_type layer head\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n"," imshow(all_paths, facet_col=0, facet_labels=[\"Query (In)\", \"Key (In)\", \"Value (In)\", \"Query (Out)\", \"Key (Out)\", \"Value (Out)\"], title=f\"Input and Output Paths for head {label}\", yaxis=\"Layer\", xaxis=\"Head\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Validating Attribution vs Activation Patching\n"," Let's now compare attribution and activation patching. Generally it's a decent approximation! The main place it fails is MLP0 and the residual stream\n"," My fuzzy intuition is that attribution patching works badly for \"big\" things which are poorly modelled as linear approximations, and works well for \"small\" things which are more like incremental changes. Anything involving replacing the embedding is a \"big\" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an \"extended embedding\" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.\n"," See more discussion in the accompanying blog post!\n"]},{"cell_type":"markdown","metadata":{},"source":[" First do some refactoring to make attribution patching more generic. 
We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!"]},{"cell_type":"code","execution_count":27,"metadata":{},"outputs":[],"source":["attribution_cache_dict = {}\n","for key in corrupted_grad_cache.cache_dict.keys():\n"," attribution_cache_dict[key] = corrupted_grad_cache.cache_dict[key] * (clean_cache.cache_dict[key] - corrupted_cache.cache_dict[key])\n","attr_cache = ActivationCache(attribution_cache_dict, model)"]},{"cell_type":"markdown","metadata":{},"source":[" By block: For each head we patch the starting residual stream, attention output + MLP output"]},{"cell_type":"code","execution_count":28,"metadata":{},"outputs":[],"source":["str_tokens = model.to_str_tokens(clean_tokens[0])\n","context_length = len(str_tokens)"]},{"cell_type":"code","execution_count":29,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"95a5290e11b64b6a95ef5dd37d027c7a","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_block_act_patch_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_block_act_patch_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Activation Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"]},{"cell_type":"code","execution_count":30,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_block_every(attr_cache):\n"," resid_pre_attr = einops.reduce(\n"," attr_cache.stack_activation(\"resid_pre\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," attn_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"attn_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," mlp_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"mlp_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n","\n"," every_block_attr_patch_result = torch.stack([resid_pre_attr, attn_out_attr, mlp_out_attr], dim=0)\n"," return every_block_attr_patch_result\n","every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)\n","imshow(every_block_attr_patch_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Attribution Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"]},{"cell_type":"code","execution_count":31,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_block_attr_patch_result.reshape(3, -1), x=every_block_act_patch_result.reshape(3, -1), facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Attribution vs Activation Patching Per Block\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", hover=[f\"Layer {l}, Position {p}, |{str_tokens[p]}|\" for l in range(model.cfg.n_layers) for p in range(context_length)], color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer pos)\", pos=context_length), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow."]},{"cell_type":"code","execution_count":32,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"18b2e6b0985b40cd8c0cd1a16ba62975","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)"]},{"cell_type":"code","execution_count":33,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_all_pos_every(attr_cache):\n"," head_out_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_q_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_k_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_v_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack([head_out_all_pos_attr, head_q_all_pos_attr, head_k_all_pos_attr, head_v_all_pos_attr, head_pattern_all_pos_attr])\n"," \n","every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(attr_cache)\n","imshow(every_head_all_pos_attr_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)"]},{"cell_type":"code","execution_count":34,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_head_all_pos_attr_patch_result.reshape(5, -1), x=every_head_all_pos_act_patch_result.reshape(5, -1), facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution vs Activation Patching Per Head (All Pos)\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", include_diag=True, hover=head_out_labels, color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer head)\", head=model.cfg.n_heads), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" We see pretty good results in general, but significant errors for heads L5H5 on query and moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. But each of these is fine for pattern and output. My guess is that the problem is that these have pretty saturated attention on a single token, and the linear approximation is thus not great on the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!\n"," Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has indirect object and subject in the opposite order (IO first). But otherwise everything lines up between the prompts"]},{"cell_type":"code","execution_count":35,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["graph_tok_labels = [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]\n","imshow(clean_cache[\"pattern\", 5][:, 5], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L5H5\", facet_name=\"Prompt\")\n","imshow(clean_cache[\"pattern\", 10][:, 7], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L10H7\", facet_name=\"Prompt\")\n","imshow(clean_cache[\"pattern\", 11][:, 10], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L11H10\", facet_name=\"Prompt\")\n","\n","\n","# [markdown]"]},{"cell_type":"code","execution_count":36,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"06f39489001845849fbc7446a07066f4","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","every_head_by_pos_act_patch_result = einops.rearrange(every_head_by_pos_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n","imshow(every_head_by_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=head_out_labels)"]},{"cell_type":"code","execution_count":37,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_by_pos_every(attr_cache):\n"," head_out_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_q_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_k_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_v_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer dest_pos head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack([head_out_by_pos_attr, head_q_by_pos_attr, head_k_by_pos_attr, head_v_by_pos_attr, head_pattern_by_pos_attr])\n","every_head_by_pos_attr_patch_result = get_attr_patch_attn_head_by_pos_every(attr_cache)\n","every_head_by_pos_attr_patch_result = einops.rearrange(every_head_by_pos_attr_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n","imshow(every_head_by_pos_attr_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=head_out_labels)"]},{"cell_type":"code","execution_count":38,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_head_by_pos_attr_patch_result.reshape(5, -1), x=every_head_by_pos_act_patch_result.reshape(5, -1), facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution vs Activation Patching Per Head (by Pos)\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", include_diag=True, hover=[f\"{label} {tok}\" for label in head_out_labels for tok in graph_tok_labels], color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer head pos)\", head=model.cfg.n_heads, pos = 15), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Factual Knowledge Patching Example\n"," Incomplete, but maybe of interest!\n"," Note that I have better results with the corrupted prompt as having random words rather than Colosseum."]},{"cell_type":"code","execution_count":39,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-xl into HookedTransformer\n","Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Paris']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.73 Prob: 95.80% Token: | Paris|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.73\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m95.80\u001b[0m\u001b[1m% Token: | Paris|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n","Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n","Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n","Top 3th token. Logit: 14.58 Prob: 0.21% Token: | É|\n","Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n","Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n","Top 6th token. Logit: 13.77 Prob: 0.09% Token: | Le|\n","Top 7th token. Logit: 13.66 Prob: 0.08% Token: | Ang|\n","Top 8th token. Logit: 13.43 Prob: 0.06% Token: | V|\n","Top 9th token. Logit: 13.42 Prob: 0.06% Token: | Stras|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Paris', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Paris'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Tokenized prompt: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Rome']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.02 Prob: 83.70% Token: | Rome|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.02\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m83.70\u001b[0m\u001b[1m% Token: | Rome|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.02 Prob: 83.70% Token: | Rome|\n","Top 1th token. Logit: 17.03 Prob: 4.23% Token: | Naples|\n","Top 2th token. Logit: 16.85 Prob: 3.51% Token: | Pompe|\n","Top 3th token. Logit: 16.14 Prob: 1.73% Token: | Ver|\n","Top 4th token. Logit: 15.87 Prob: 1.32% Token: | Florence|\n","Top 5th token. Logit: 14.77 Prob: 0.44% Token: | Roma|\n","Top 6th token. Logit: 14.68 Prob: 0.40% Token: | Milan|\n","Top 7th token. Logit: 14.66 Prob: 0.39% Token: | ancient|\n","Top 8th token. Logit: 14.37 Prob: 0.29% Token: | Pal|\n","Top 9th token. Logit: 14.30 Prob: 0.27% Token: | Constantinople|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Rome', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Rome'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"}],"source":["gpt2_xl = HookedTransformer.from_pretrained(\"gpt2-xl\")\n","clean_prompt = \"The Eiffel Tower is located in the city of\"\n","clean_answer = \" Paris\"\n","# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n","corrupted_prompt = \"The Colosseum is located in the city of\"\n","corrupted_answer = \" Rome\"\n","utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)\n","utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)"]},{"cell_type":"code","execution_count":40,"metadata":{},"outputs":[],"source":["clean_answer_index = gpt2_xl.to_single_token(clean_answer)\n","corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)\n","def factual_logit_diff(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return logits[0, -1, clean_answer_index] - logits[0, -1, corrupted_answer_index]"]},{"cell_type":"code","execution_count":41,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 10.634519577026367\n","Corrupted logit diff: -8.988396644592285\n","Clean Metric: tensor(1., device='cuda:0', grad_fn=)\n","Corrupted Metric: tensor(0., device='cuda:0', grad_fn=)\n"]}],"source":["clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)\n","CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()\n","corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)\n","CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()\n","\n","def factual_metric(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return (factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL) / (CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean logit diff:\", CLEAN_LOGIT_DIFF_FACTUAL)\n","print(\"Corrupted 
logit diff:\", CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean Metric:\", factual_metric(clean_logits))\n","print(\"Corrupted Metric:\", factual_metric(corrupted_logits))"]},{"cell_type":"code","execution_count":42,"metadata":{},"outputs":[],"source":["# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)"]},{"cell_type":"code","execution_count":43,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Corrupted: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n"]}],"source":["clean_tokens = gpt2_xl.to_tokens(clean_prompt)\n","clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)\n","corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)\n","corrupted_str_tokens = gpt2_xl.to_str_tokens(corrupted_prompt)\n","print(\"Clean:\", clean_str_tokens)\n","print(\"Corrupted:\", corrupted_str_tokens)"]},{"cell_type":"code","execution_count":44,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"b767eef7a3cd49b9b3cb6e5301463f08","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/48 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def act_patch_residual(clean_cache, corrupted_tokens, model: HookedTransformer, metric):\n"," if len(corrupted_tokens.shape)==2:\n"," corrupted_tokens = corrupted_tokens[0]\n"," residual_patches = torch.zeros((model.cfg.n_layers, len(corrupted_tokens)), device=model.cfg.device)\n"," def residual_hook(resid_pre, hook, layer, pos):\n"," resid_pre[:, pos, :] = clean_cache[\"resid_pre\", layer][:, pos, :]\n"," return resid_pre\n"," for layer in tqdm.tqdm(range(model.cfg.n_layers)):\n"," for pos in range(len(corrupted_tokens)):\n"," patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[(f\"blocks.{layer}.hook_resid_pre\", partial(residual_hook, layer=layer, pos=pos))])\n"," residual_patches[layer, pos] = metric(patched_logits).item()\n"," return residual_patches\n","\n","residual_act_patch = act_patch_residual(clean_cache, corrupted_tokens, gpt2_xl, factual_metric)\n","\n","imshow(residual_act_patch, title=\"Factual Recall Patching (Residual)\", xaxis=\"Position\", yaxis=\"Layer\", x=clean_str_tokens)"]}],"metadata":{"kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.13"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"}}},"nbformat":4,"nbformat_minor":2} +{"cells":[{"cell_type":"markdown","metadata":{},"source":["\n"," \"Open\n",""]},{"cell_type":"markdown","metadata":{},"source":[" # Attribution Patching Demo\n"," **Read [the accompanying blog post here](https://neelnanda.io/attribution-patching) for more context**\n"," This is an interim research report, giving a whirlwind tour of some unpublished work I did at Anthropic (credit to the then team - Chris Olah, Catherine Olsson, Nelson Elhage and 
Tristan Hume for help, support, and mentorship!)\n","\n"," The goal of this work is to run activation patching at an industrial scale, by using gradient based attribution to approximate the technique - allowing an arbitrary number of patches to be made on two forwards and a single backward pass\n","\n"," I have had less time than hoped to flesh out this investigation, but am writing up a rough investigation and comparison to standard activation patching on a few tasks to give a sense of the potential of this approach, and where it works vs falls down."]},{"cell_type":"markdown","metadata":{},"source":[" To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n","\n"," **Tips for reading this Colab:**\n"," * You can run all this code for yourself!\n"," * The graphs are interactive!\n"," * Use the table of contents pane in the sidebar to navigate\n"," * Collapse irrelevant sections with the dropdown arrows\n"," * Search the page using the search in the sidebar, not CTRL+F"]},{"cell_type":"markdown","metadata":{},"source":[" ## Setup (Ignore)"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Running as a Jupyter notebook - intended for development only!\n"]}],"source":["# Janky code to do different setup when run in a Colab notebook vs VSCode\n","DEBUG_MODE = False\n","try:\n"," import google.colab\n"," IN_COLAB = True\n"," print(\"Running as a Colab notebook\")\n"," %pip install transformer_lens==1.1.1\n"," %pip install torchtyping\n"," # Install my janky personal plotting utils\n"," %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n"," # Install another version of node that makes PySvelte work way faster\n"," !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n"," %pip install git+https://github.com/neelnanda-io/PySvelte.git\n"," # Needed for PySvelte to work, v3 came out and broke things...\n"," 
%pip install typeguard==2.13.3\n","except:\n"," IN_COLAB = False\n"," print(\"Running as a Jupyter notebook - intended for development only!\")\n"," from IPython import get_ipython\n","\n"," ipython = get_ipython()\n"," # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n"," ipython.magic(\"load_ext autoreload\")\n"," ipython.magic(\"autoreload 2\")"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":["# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n","import plotly.io as pio\n","\n","if IN_COLAB or not DEBUG_MODE:\n"," # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.\n"," pio.renderers.default = \"colab\"\n","else:\n"," pio.renderers.default = \"notebook_connected\""]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":["# Import stuff\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import numpy as np\n","import einops\n","from fancy_einsum import einsum\n","import tqdm.notebook as tqdm\n","import random\n","from pathlib import Path\n","import plotly.express as px\n","from torch.utils.data import DataLoader\n","\n","from torchtyping import TensorType as TT\n","from typing import List, Union, Optional, Callable\n","from functools import partial\n","import copy\n","import itertools\n","import json\n","\n","from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n","import dataclasses\n","import datasets\n","from IPython.display import HTML, Markdown"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":["import pysvelte\n","\n","import transformer_lens\n","import transformer_lens.utils as utils\n","from transformer_lens.hook_points import (\n"," HookedRootModule,\n"," HookPoint,\n",") # Hooking 
utilities\n","from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache"]},{"cell_type":"markdown","metadata":{},"source":[" Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":["from neel_plotly import line, imshow, scatter"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["import transformer_lens.patching as patching"]},{"cell_type":"markdown","metadata":{},"source":[" ## IOI Patching Setup\n"," This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important."]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-small into HookedTransformer\n"]}],"source":["model = HookedTransformer.from_pretrained(\"gpt2-small\")\n","model.set_use_attn_result(True)"]},{"cell_type":"code","execution_count":9,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to\n","Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n","Answer token indices tensor([[ 5335, 1757],\n"," [ 1757, 5335],\n"," [ 4186, 3700],\n"," [ 3700, 4186],\n"," [ 6035, 15686],\n"," [15686, 6035],\n"," [ 5780, 14235],\n"," [14235, 5780]], device='cuda:0')\n"]}],"source":["prompts = ['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When 
Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']\n","answers = [(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n","\n","clean_tokens = model.to_tokens(prompts)\n","# Swap each adjacent pair, with a hacky list comprehension\n","corrupted_tokens = clean_tokens[\n"," [(i+1 if i%2==0 else i-1) for i in range(len(clean_tokens)) ]\n"," ]\n","print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n","print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n","\n","answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=model.cfg.device)\n","print(\"Answer token indices\", answer_token_indices)"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 3.5519\n","Corrupted logit diff: -3.5519\n"]}],"source":["def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n"," if len(logits.shape)==3:\n"," # Get final logits only\n"," logits = logits[:, -1, :]\n"," correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n"," incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n"," return (correct_logits - incorrect_logits).mean()\n","\n","clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n","corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n","\n","clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n","print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n","\n","corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n","print(f\"Corrupted logit diff: 
{corrupted_logit_diff:.4f}\")"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Baseline is 1: 1.0000\n","Corrupted Baseline is 0: 0.0000\n"]}],"source":["CLEAN_BASELINE = clean_logit_diff\n","CORRUPTED_BASELINE = corrupted_logit_diff\n","def ioi_metric(logits, answer_token_indices=answer_token_indices):\n"," return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE - CORRUPTED_BASELINE)\n","\n","print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n","print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Patching\n"," In the following cells, we define attribution patching and use it in various ways on the model."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["Metric = Callable[[TT[\"batch_and_pos_dims\", \"d_model\"]], float]"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Value: 1.0\n","Clean Activations Cached: 220\n","Clean Gradients Cached: 220\n","Corrupted Value: 0.0\n","Corrupted Activations Cached: 220\n","Corrupted Gradients Cached: 220\n"]}],"source":["filter_not_qkv_input = lambda name: \"_input\" not in name\n","def get_cache_fwd_and_bwd(model, tokens, metric):\n"," model.reset_hooks()\n"," cache = {}\n"," def forward_cache_hook(act, hook):\n"," cache[hook.name] = act.detach()\n"," model.add_hook(filter_not_qkv_input, forward_cache_hook, \"fwd\")\n","\n"," grad_cache = {}\n"," def backward_cache_hook(act, hook):\n"," grad_cache[hook.name] = act.detach()\n"," model.add_hook(filter_not_qkv_input, backward_cache_hook, \"bwd\")\n","\n"," value = metric(model(tokens))\n"," value.backward()\n"," model.reset_hooks()\n"," return value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)\n","\n","clean_value, clean_cache, 
clean_grad_cache = get_cache_fwd_and_bwd(model, clean_tokens, ioi_metric)\n","print(\"Clean Value:\", clean_value)\n","print(\"Clean Activations Cached:\", len(clean_cache))\n","print(\"Clean Gradients Cached:\", len(clean_grad_cache))\n","corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(model, corrupted_tokens, ioi_metric)\n","print(\"Corrupted Value:\", corrupted_value)\n","print(\"Corrupted Activations Cached:\", len(corrupted_cache))\n","print(\"Corrupted Gradients Cached:\", len(corrupted_grad_cache))"]},{"cell_type":"markdown","metadata":{},"source":[" ### Attention Attribution\n"," The easiest thing to start with is to not even engage with the corrupted tokens/patching, but to look at the attribution of the attention patterns - that is, the linear approximation to what happens if you set each element of the attention pattern to zero. This, as it turns out, is a good proxy to what is going on with each head!\n"," Note that this is *not* the same as what we will later do with patching. In particular, this does not set up a careful counterfactual! It's a good tool for what's generally going on in this problem, but does not control for eg stuff that systematically boosts John > Mary in general, stuff that says \"I should activate the IOI circuit\", etc. Though using logit diff as our metric *does*\n"," Each element of the batch is independent and the metric is an average logit diff, so we can analyse each batch element independently here. 
We'll look at the first one, and then at the average across the whole batch (note - 4 prompts have indirect object before subject, 4 prompts have it the other way round, making the average pattern harder to interpret - I plot it over the first sequence of tokens as a mildly misleading reference).\n"," We can compare it to the interpretability in the wild diagram, and basically instantly recover most of the circuit!"]},{"cell_type":"code","execution_count":14,"metadata":{},"outputs":[],"source":["def create_attention_attr(clean_cache, clean_grad_cache) -> TT[\"batch\", \"layer\", \"head_index\", \"dest\", \"src\"]:\n"," attention_stack = torch.stack([clean_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," attention_grad_stack = torch.stack([clean_grad_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," attention_attr = attention_grad_stack * attention_stack\n"," attention_attr = einops.rearrange(attention_attr, \"layer batch head_index dest src -> batch layer head_index dest src\")\n"," return attention_attr\n","\n","attention_attr = create_attention_attr(clean_cache, clean_grad_cache)"]},{"cell_type":"code","execution_count":15,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4']\n","['L0H0+', 'L0H0-', 'L0H1+', 'L0H1-', 'L0H2+']\n","['L0H0Q', 'L0H0K', 'L0H0V', 'L0H1Q', 'L0H1K']\n"]}],"source":["HEAD_NAMES = [f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]\n","HEAD_NAMES_SIGNED = [f\"{name}{sign}\" for name in HEAD_NAMES for sign in [\"+\", \"-\"]]\n","HEAD_NAMES_QKV = [f\"{name}{act_name}\" for name in HEAD_NAMES for act_name in [\"Q\", \"K\", \"V\"]]\n","print(HEAD_NAMES[:5])\n","print(HEAD_NAMES_SIGNED[:5])\n","print(HEAD_NAMES_QKV[:5])"]},{"cell_type":"markdown","metadata":{},"source":[" An extremely janky way to plot the attention attribution patterns. 
We scale them to be in [-1, 1], split each head into a positive and negative part (so all of it is in [0, 1]), and then plot the top 20 head-halves (a head can appear twice!) by the max value of the attribution pattern."]},{"cell_type":"code","execution_count":16,"metadata":{},"outputs":[{"data":{"text/markdown":["### Attention Attribution for first sequence"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["### Summed Attention Attribution for all sequences"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\n"]}],"source":["def plot_attention_attr(attention_attr, tokens, top_k=20, index=0, title=\"\"):\n"," if len(tokens.shape)==2:\n"," tokens = tokens[index]\n"," if len(attention_attr.shape)==5:\n"," attention_attr = attention_attr[index]\n"," attention_attr_pos = attention_attr.clamp(min=-1e-5)\n"," attention_attr_neg = - attention_attr.clamp(max=1e-5)\n"," attention_attr_signed = torch.stack([attention_attr_pos, attention_attr_neg], dim=0)\n"," attention_attr_signed = einops.rearrange(attention_attr_signed, \"sign layer head_index dest src -> (layer head_index sign) dest src\")\n"," attention_attr_signed = attention_attr_signed / attention_attr_signed.max()\n"," attention_attr_indices = attention_attr_signed.max(-1).values.max(-1).values.argsort(descending=True)\n"," # print(attention_attr_indices.shape)\n"," # print(attention_attr_indices)\n"," attention_attr_signed = attention_attr_signed[attention_attr_indices, :, :]\n"," head_labels = [HEAD_NAMES_SIGNED[i.item()] for i in attention_attr_indices]\n","\n"," if title: display(Markdown(\"### \"+title))\n"," display(pysvelte.AttentionMulti(tokens=model.to_str_tokens(tokens), attention=attention_attr_signed.permute(1, 2, 0)[:, :, :top_k], head_labels=head_labels[:top_k]))\n","\n","plot_attention_attr(attention_attr, clean_tokens, index=0, title=\"Attention Attribution for first sequence\")\n","\n","plot_attention_attr(attention_attr.sum(0), clean_tokens[0], title=\"Summed Attention Attribution for all sequences\")\n","print(\"Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Attribution Patching\n"," In the following sections, I will implement various kinds of attribution patching, and 
then compare them to the activation patching patterns (activation patching code copied from [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo))\n"," ### Residual Stream Patching\n","
Note: We add up across both d_model and batch (Explanation).\n"," We add up along d_model because we're taking the dot product - the derivative *is* the linear map that locally linearly approximates the metric, and so we take the dot product of our change vector with the derivative vector. Equivalent, we look at the effect of changing each coordinate independently, and then combine them by adding it up - it's linear, so this totally works.\n"," We add up across batch because we're taking the average of the metric, so each individual batch element provides `1/batch_size` of the overall effect. Because each batch element is independent of the others and no information moves between activations for different inputs, the batched version is equivalent to doing attribution patching separately for each input, and then averaging - in this second version the metric per input is *not* divided by batch_size because we don't average.
"]},{"cell_type":"code","execution_count":17,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_residual(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," clean_residual, residual_labels = clean_cache.accumulated_resid(-1, incl_mid=True, return_labels=True)\n"," corrupted_residual = corrupted_cache.accumulated_resid(-1, incl_mid=True, return_labels=False)\n"," corrupted_grad_residual = corrupted_grad_cache.accumulated_resid(-1, incl_mid=True, return_labels=False)\n"," residual_attr = einops.reduce(\n"," corrupted_grad_residual * (clean_residual - corrupted_residual),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return residual_attr, residual_labels\n","\n","residual_attr, residual_labels = attr_patch_residual(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(residual_attr, y=residual_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Residual Attribution Patching\")\n","\n","# ### Layer Output Patching"]},{"cell_type":"code","execution_count":18,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_layer_out(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," clean_layer_out, labels = clean_cache.decompose_resid(-1, return_labels=True)\n"," corrupted_layer_out = corrupted_cache.decompose_resid(-1, return_labels=False)\n"," corrupted_grad_layer_out = corrupted_grad_cache.decompose_resid(-1, return_labels=False)\n"," layer_out_attr = einops.reduce(\n"," corrupted_grad_layer_out * (clean_layer_out - corrupted_layer_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return layer_out_attr, labels\n","\n","layer_out_attr, layer_out_labels = attr_patch_layer_out(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(layer_out_attr, y=layer_out_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Layer Output Attribution Patching\")"]},{"cell_type":"code","execution_count":19,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_head_out(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_out = clean_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_head_out = corrupted_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_grad_head_out = corrupted_grad_cache.stack_head_results(-1, return_labels=False)\n"," head_out_attr = einops.reduce(\n"," corrupted_grad_head_out * (clean_head_out - corrupted_head_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return head_out_attr, labels\n","\n","head_out_attr, head_out_labels = attr_patch_head_out(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(head_out_attr, y=head_out_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Head Output Attribution Patching\")\n","sum_head_out_attr = einops.reduce(head_out_attr, \"(layer head) pos -> layer head\", \"sum\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n","imshow(sum_head_out_attr, yaxis=\"Layer\", xaxis=\"Head Index\", title=\"Head Output Attribution Patching Sum Over Pos\")"]},{"cell_type":"markdown","metadata":{},"source":[" ### Head Activation Patching\n"," Intuitively, a head has three inputs, keys, queries and values. We can patch each of these individually to get a sense for where the important part of each head's input comes from!\n"," As a sanity check, we also do this for the mixed value. 
The result is a linear map of this (`z @ W_O == result`), so this is the same as patching the output of the head.\n"," We plot both the patch for each head over each position, and summed over position (it tends to be pretty sparse, so the latter is the same)"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[{"data":{"text/markdown":["#### Key Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Query Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Mixed Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","def stack_head_vector_from_cache(\n"," cache, \n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"]\n"," ) -> TT[\"layer_and_head_index\", \"batch\", \"pos\", \"d_head\"]:\n"," \"\"\"Stacks the head vectors from the cache from a specific activation (key, query, value or mixed_value (z)) into a single tensor.\"\"\"\n"," stacked_head_vectors = torch.stack([cache[activation_name, l] for l in range(model.cfg.n_layers)], dim=0)\n"," stacked_head_vectors = einops.rearrange(\n"," stacked_head_vectors,\n"," \"layer batch pos head_index d_head -> (layer head_index) batch pos d_head\"\n"," )\n"," return stacked_head_vectors\n","\n","def attr_patch_head_vector(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"],\n"," ) -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_vector = stack_head_vector_from_cache(clean_cache, activation_name)\n"," corrupted_head_vector = stack_head_vector_from_cache(corrupted_cache, activation_name)\n"," corrupted_grad_head_vector = stack_head_vector_from_cache(corrupted_grad_cache, activation_name)\n"," head_vector_attr = einops.reduce(\n"," corrupted_grad_head_vector * (clean_head_vector - corrupted_head_vector),\n"," \"component batch pos d_head -> component pos\",\n"," \"sum\"\n"," )\n"," return head_vector_attr, labels\n","\n","head_vector_attr_dict = {}\n","for activation_name, activation_name_full in [(\"k\", \"Key\"), (\"q\", \"Query\"), (\"v\", \"Value\"), (\"z\", \"Mixed Value\")]:\n"," display(Markdown(f\"#### {activation_name_full} Head Vector Attribution Patching\"))\n"," head_vector_attr_dict[activation_name], head_vector_labels = attr_patch_head_vector(clean_cache, corrupted_cache, corrupted_grad_cache, activation_name)\n"," 
imshow(head_vector_attr_dict[activation_name], y=head_vector_labels, yaxis=\"Component\", xaxis=\"Position\", title=f\"{activation_name_full} Attribution Patching\")\n"," sum_head_vector_attr = einops.reduce(head_vector_attr_dict[activation_name], \"(layer head) pos -> layer head\", \"sum\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n"," imshow(sum_head_vector_attr, yaxis=\"Layer\", xaxis=\"Head Index\", title=f\"{activation_name_full} Attribution Patching Sum Over Pos\")"]},{"cell_type":"code","execution_count":21,"metadata":{},"outputs":[{"data":{"text/markdown":["### Head Pattern Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","def stack_head_pattern_from_cache(\n"," cache, \n"," ) -> TT[\"layer_and_head_index\", \"batch\", \"dest_pos\", \"src_pos\"]:\n"," \"\"\"Stacks the head patterns from the cache into a single tensor.\"\"\"\n"," stacked_head_pattern = torch.stack([cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," stacked_head_pattern = einops.rearrange(\n"," stacked_head_pattern,\n"," \"layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos\"\n"," )\n"," return stacked_head_pattern\n","\n","def attr_patch_head_pattern(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"dest_pos\", \"src_pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_pattern = stack_head_pattern_from_cache(clean_cache)\n"," corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)\n"," corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)\n"," head_pattern_attr = einops.reduce(\n"," corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern),\n"," \"component batch dest_pos src_pos -> component dest_pos src_pos\",\n"," \"sum\"\n"," )\n"," return head_pattern_attr, labels\n","\n","head_pattern_attr, labels = attr_patch_head_pattern(clean_cache, corrupted_cache, corrupted_grad_cache)\n","\n","plot_attention_attr(einops.rearrange(head_pattern_attr, \"(layer head) dest src -> layer head dest src\", layer=model.cfg.n_layers, head=model.cfg.n_heads), clean_tokens, index=0, title=\"Head Pattern Attribution Patching\")"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_head_vector_grad_input_from_grad_cache(\n"," grad_cache: ActivationCache, \n"," activation_name: Literal[\"q\", \"k\", \"v\"],\n"," layer: int\n"," ) -> TT[\"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," vector_grad = grad_cache[activation_name, layer]\n"," ln_scales = grad_cache[\"scale\", layer, \"ln1\"]\n"," attn_layer_object = model.blocks[layer].attn\n"," if activation_name == \"q\":\n"," W = attn_layer_object.W_Q\n"," elif activation_name == \"k\":\n"," W = attn_layer_object.W_K\n"," elif activation_name == \"v\":\n"," W = attn_layer_object.W_V\n"," else:\n"," raise ValueError(\"Invalid activation name\")\n","\n"," return einsum(\"batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model\", vector_grad, ln_scales.squeeze(-1), W)\n","\n","def get_stacked_head_vector_grad_input(grad_cache, activation_name: Literal[\"q\", \"k\", \"v\"]) -> TT[\"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack([get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, l) for l in range(model.cfg.n_layers)], dim=0)\n","\n","def get_full_vector_grad_input(grad_cache) -> TT[\"qkv\", \"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack([get_stacked_head_vector_grad_input(grad_cache, activation_name) for activation_name in ['q', 'k', 'v']], dim=0)\n","\n","def attr_patch_head_path(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache\n"," ) -> TT[\"qkv\", \"dest_component\", \"src_component\", \"pos\"]:\n"," \"\"\"\n"," Computes the attribution patch along the path between each pair of heads.\n","\n"," Sets this to zero for the path from any late head to any early head\n","\n"," \"\"\"\n"," start_labels = HEAD_NAMES\n"," end_labels = HEAD_NAMES_QKV\n"," full_vector_grad_input = 
get_full_vector_grad_input(corrupted_grad_cache)\n"," clean_head_result_stack = clean_cache.stack_head_results(-1)\n"," corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)\n"," diff_head_result = einops.rearrange(\n"," clean_head_result_stack - corrupted_head_result_stack,\n"," \"(layer head_index) batch pos d_model -> layer batch pos head_index d_model\",\n"," layer = model.cfg.n_layers,\n"," head_index = model.cfg.n_heads,\n"," )\n"," path_attr = einsum(\n"," \"qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos\", \n"," full_vector_grad_input, \n"," diff_head_result)\n"," correct_layer_order_mask = (\n"," torch.arange(model.cfg.n_layers)[None, :, None, None, None, None] > \n"," torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]).to(path_attr.device)\n"," zero = torch.zeros(1, device=path_attr.device)\n"," path_attr = torch.where(correct_layer_order_mask, path_attr, zero)\n","\n"," path_attr = einops.rearrange(\n"," path_attr,\n"," \"qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos\",\n"," )\n"," return path_attr, end_labels, start_labels\n","\n","head_path_attr, end_labels, start_labels = attr_patch_head_path(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(head_path_attr.sum(-1), y=end_labels, yaxis=\"Path End (Head Input)\", x=start_labels, xaxis=\"Path Start (Head Output)\", title=\"Head Path Attribution Patching\")"]},{"cell_type":"markdown","metadata":{},"source":[" This is hard to parse. Here's an experiment with filtering for the most important heads and showing their paths."]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)\n","line(head_out_values)\n","top_head_indices = head_out_indices[:22].sort().values\n","top_end_indices = []\n","top_end_labels = []\n","top_start_indices = []\n","top_start_labels = []\n","for i in top_head_indices:\n"," i = i.item()\n"," top_start_indices.append(i)\n"," top_start_labels.append(start_labels[i])\n"," for j in range(3):\n"," top_end_indices.append(3*i+j)\n"," top_end_labels.append(end_labels[3*i+j])\n","\n","imshow(head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1), y=top_end_labels, yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=\"Head Path Attribution Patching (Filtered for Top Heads)\")"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["for j, composition_type in enumerate([\"Query\", \"Key\", \"Value\"]):\n"," imshow(head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1), y=top_end_labels[j::3], yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=f\"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)\")"]},{"cell_type":"code","execution_count":25,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["top_head_path_attr = einops.rearrange(head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1), \"(head_end qkv) head_start -> qkv head_end head_start\", qkv=3)\n","imshow(top_head_path_attr, y=[i[:-1] for i in top_end_labels[::3]], yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=f\"Head Path Attribution Patching (Filtered for Top Heads)\", facet_col=0, facet_labels=[\"Query\", \"Key\", \"Value\"])"]},{"cell_type":"markdown","metadata":{},"source":[" Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - Q input means )"]},{"cell_type":"code","execution_count":26,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["interesting_heads = [5 * model.cfg.n_heads + 5, 8 * model.cfg.n_heads + 6, 9 * model.cfg.n_heads + 9]\n","interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]\n","for head_index, label in zip(interesting_heads, interesting_head_labels):\n"," in_paths = head_path_attr[3*head_index:3*head_index+3].sum(-1)\n"," out_paths = head_path_attr[:, head_index].sum(-1)\n"," out_paths = einops.rearrange(out_paths, \"(layer_head qkv) -> qkv layer_head\", qkv=3)\n"," all_paths = torch.cat([in_paths, out_paths], dim=0)\n"," all_paths = einops.rearrange(all_paths, \"path_type (layer head) -> path_type layer head\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n"," imshow(all_paths, facet_col=0, facet_labels=[\"Query (In)\", \"Key (In)\", \"Value (In)\", \"Query (Out)\", \"Key (Out)\", \"Value (Out)\"], title=f\"Input and Output Paths for head {label}\", yaxis=\"Layer\", xaxis=\"Head\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Validating Attribution vs Activation Patching\n"," Let's now compare attribution and activation patching. Generally it's a decent approximation! The main place it fails is MLP0 and the residual stream\n"," My fuzzy intuition is that attribution patching works badly for \"big\" things which are poorly modelled as linear approximations, and works well for \"small\" things which are more like incremental changes. Anything involving replacing the embedding is a \"big\" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an \"extended embedding\" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.\n"," See more discussion in the accompanying blog post!\n"]},{"cell_type":"markdown","metadata":{},"source":[" First do some refactoring to make attribution patching more generic. 
We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!"]},{"cell_type":"code","execution_count":27,"metadata":{},"outputs":[],"source":["attribution_cache_dict = {}\n","for key in corrupted_grad_cache.cache_dict.keys():\n"," attribution_cache_dict[key] = corrupted_grad_cache.cache_dict[key] * (clean_cache.cache_dict[key] - corrupted_cache.cache_dict[key])\n","attr_cache = ActivationCache(attribution_cache_dict, model)"]},{"cell_type":"markdown","metadata":{},"source":[" By block: For each head we patch the starting residual stream, attention output + MLP output"]},{"cell_type":"code","execution_count":28,"metadata":{},"outputs":[],"source":["str_tokens = model.to_str_tokens(clean_tokens[0])\n","context_length = len(str_tokens)"]},{"cell_type":"code","execution_count":29,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"95a5290e11b64b6a95ef5dd37d027c7a","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_block_act_patch_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_block_act_patch_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Activation Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"]},{"cell_type":"code","execution_count":30,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_block_every(attr_cache):\n"," resid_pre_attr = einops.reduce(\n"," attr_cache.stack_activation(\"resid_pre\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," attn_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"attn_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," mlp_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"mlp_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n","\n"," every_block_attr_patch_result = torch.stack([resid_pre_attr, attn_out_attr, mlp_out_attr], dim=0)\n"," return every_block_attr_patch_result\n","every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)\n","imshow(every_block_attr_patch_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Attribution Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"]},{"cell_type":"code","execution_count":31,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_block_attr_patch_result.reshape(3, -1), x=every_block_act_patch_result.reshape(3, -1), facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Attribution vs Activation Patching Per Block\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", hover=[f\"Layer {l}, Position {p}, |{str_tokens[p]}|\" for l in range(model.cfg.n_layers) for p in range(context_length)], color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer pos)\", pos=context_length), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow."]},{"cell_type":"code","execution_count":32,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"18b2e6b0985b40cd8c0cd1a16ba62975","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)"]},{"cell_type":"code","execution_count":33,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_all_pos_every(attr_cache):\n"," head_out_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_q_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_k_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_v_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack([head_out_all_pos_attr, head_q_all_pos_attr, head_k_all_pos_attr, head_v_all_pos_attr, head_pattern_all_pos_attr])\n"," \n","every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(attr_cache)\n","imshow(every_head_all_pos_attr_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)"]},{"cell_type":"code","execution_count":34,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_head_all_pos_attr_patch_result.reshape(5, -1), x=every_head_all_pos_act_patch_result.reshape(5, -1), facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution vs Activation Patching Per Head (All Pos)\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", include_diag=True, hover=head_out_labels, color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer head)\", head=model.cfg.n_heads), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" We see pretty good results in general, but significant errors for heads L5H5 on query and moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. But each of these is fine for pattern and output. My guess is that the problem is that these have pretty saturated attention on a single token, and the linear approximation is thus not great on the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!\n"," Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has indirect object and subject in the opposite order (IO first). But otherwise everything lines up between the prompts"]},{"cell_type":"code","execution_count":35,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["graph_tok_labels = [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]\n","imshow(clean_cache[\"pattern\", 5][:, 5], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L5H5\", facet_name=\"Prompt\")\n","imshow(clean_cache[\"pattern\", 10][:, 7], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L10H7\", facet_name=\"Prompt\")\n","imshow(clean_cache[\"pattern\", 11][:, 10], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L11H10\", facet_name=\"Prompt\")\n","\n","\n","# [markdown]"]},{"cell_type":"code","execution_count":36,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"06f39489001845849fbc7446a07066f4","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","every_head_by_pos_act_patch_result = einops.rearrange(every_head_by_pos_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n","imshow(every_head_by_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=head_out_labels)"]},{"cell_type":"code","execution_count":37,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_by_pos_every(attr_cache):\n"," head_out_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_q_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_k_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_v_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer dest_pos head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack([head_out_by_pos_attr, head_q_by_pos_attr, head_k_by_pos_attr, head_v_by_pos_attr, head_pattern_by_pos_attr])\n","every_head_by_pos_attr_patch_result = get_attr_patch_attn_head_by_pos_every(attr_cache)\n","every_head_by_pos_attr_patch_result = einops.rearrange(every_head_by_pos_attr_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n","imshow(every_head_by_pos_attr_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=head_out_labels)"]},{"cell_type":"code","execution_count":38,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_head_by_pos_attr_patch_result.reshape(5, -1), x=every_head_by_pos_act_patch_result.reshape(5, -1), facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution vs Activation Patching Per Head (by Pos)\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", include_diag=True, hover=[f\"{label} {tok}\" for label in head_out_labels for tok in graph_tok_labels], color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer head pos)\", head=model.cfg.n_heads, pos = 15), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Factual Knowledge Patching Example\n"," Incomplete, but maybe of interest!\n"," Note that I have better results with the corrupted prompt as having random words rather than Colosseum."]},{"cell_type":"code","execution_count":39,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-xl into HookedTransformer\n","Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Paris']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.73 Prob: 95.80% Token: | Paris|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.73\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m95.80\u001b[0m\u001b[1m% Token: | Paris|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n","Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n","Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n","Top 3th token. Logit: 14.58 Prob: 0.21% Token: | É|\n","Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n","Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n","Top 6th token. Logit: 13.77 Prob: 0.09% Token: | Le|\n","Top 7th token. Logit: 13.66 Prob: 0.08% Token: | Ang|\n","Top 8th token. Logit: 13.43 Prob: 0.06% Token: | V|\n","Top 9th token. Logit: 13.42 Prob: 0.06% Token: | Stras|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Paris', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Paris'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Tokenized prompt: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Rome']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.02 Prob: 83.70% Token: | Rome|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.02\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m83.70\u001b[0m\u001b[1m% Token: | Rome|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.02 Prob: 83.70% Token: | Rome|\n","Top 1th token. Logit: 17.03 Prob: 4.23% Token: | Naples|\n","Top 2th token. Logit: 16.85 Prob: 3.51% Token: | Pompe|\n","Top 3th token. Logit: 16.14 Prob: 1.73% Token: | Ver|\n","Top 4th token. Logit: 15.87 Prob: 1.32% Token: | Florence|\n","Top 5th token. Logit: 14.77 Prob: 0.44% Token: | Roma|\n","Top 6th token. Logit: 14.68 Prob: 0.40% Token: | Milan|\n","Top 7th token. Logit: 14.66 Prob: 0.39% Token: | ancient|\n","Top 8th token. Logit: 14.37 Prob: 0.29% Token: | Pal|\n","Top 9th token. Logit: 14.30 Prob: 0.27% Token: | Constantinople|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Rome', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Rome'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"}],"source":["gpt2_xl = HookedTransformer.from_pretrained(\"gpt2-xl\")\n","clean_prompt = \"The Eiffel Tower is located in the city of\"\n","clean_answer = \" Paris\"\n","# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n","corrupted_prompt = \"The Colosseum is located in the city of\"\n","corrupted_answer = \" Rome\"\n","utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)\n","utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)"]},{"cell_type":"code","execution_count":40,"metadata":{},"outputs":[],"source":["clean_answer_index = gpt2_xl.to_single_token(clean_answer)\n","corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)\n","def factual_logit_diff(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return logits[0, -1, clean_answer_index] - logits[0, -1, corrupted_answer_index]"]},{"cell_type":"code","execution_count":41,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 10.634519577026367\n","Corrupted logit diff: -8.988396644592285\n","Clean Metric: tensor(1., device='cuda:0', grad_fn=)\n","Corrupted Metric: tensor(0., device='cuda:0', grad_fn=)\n"]}],"source":["clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)\n","CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()\n","corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)\n","CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()\n","\n","def factual_metric(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return (factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL) / (CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean logit diff:\", CLEAN_LOGIT_DIFF_FACTUAL)\n","print(\"Corrupted 
logit diff:\", CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean Metric:\", factual_metric(clean_logits))\n","print(\"Corrupted Metric:\", factual_metric(corrupted_logits))"]},{"cell_type":"code","execution_count":42,"metadata":{},"outputs":[],"source":["# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)"]},{"cell_type":"code","execution_count":43,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Corrupted: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n"]}],"source":["clean_tokens = gpt2_xl.to_tokens(clean_prompt)\n","clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)\n","corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)\n","corrupted_str_tokens = gpt2_xl.to_str_tokens(corrupted_prompt)\n","print(\"Clean:\", clean_str_tokens)\n","print(\"Corrupted:\", corrupted_str_tokens)"]},{"cell_type":"code","execution_count":44,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"b767eef7a3cd49b9b3cb6e5301463f08","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/48 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def act_patch_residual(clean_cache, corrupted_tokens, model: HookedTransformer, metric):\n"," if len(corrupted_tokens.shape)==2:\n"," corrupted_tokens = corrupted_tokens[0]\n"," residual_patches = torch.zeros((model.cfg.n_layers, len(corrupted_tokens)), device=model.cfg.device)\n"," def residual_hook(resid_pre, hook, layer, pos):\n"," resid_pre[:, pos, :] = clean_cache[\"resid_pre\", layer][:, pos, :]\n"," return resid_pre\n"," for layer in tqdm.tqdm(range(model.cfg.n_layers)):\n"," for pos in range(len(corrupted_tokens)):\n"," patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[(f\"blocks.{layer}.hook_resid_pre\", partial(residual_hook, layer=layer, pos=pos))])\n"," residual_patches[layer, pos] = metric(patched_logits).item()\n"," return residual_patches\n","\n","residual_act_patch = act_patch_residual(clean_cache, corrupted_tokens, gpt2_xl, factual_metric)\n","\n","imshow(residual_act_patch, title=\"Factual Recall Patching (Residual)\", xaxis=\"Position\", yaxis=\"Layer\", x=clean_str_tokens)"]}],"metadata":{"kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.13"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"}}},"nbformat":4,"nbformat_minor":2} From d2415b4cd92805da7dd3791bf71618bfde33f751 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Tue, 16 Apr 2024 23:46:52 +0200 Subject: [PATCH 57/73] Demo no position fix (#544) * fixed install version and key name * fixed remaining issues with no position experiment * removed extra key --- demos/No_Position_Experiment.ipynb | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/demos/No_Position_Experiment.ipynb 
b/demos/No_Position_Experiment.ipynb index d784f2518..98b2ddf2a 100644 --- a/demos/No_Position_Experiment.ipynb +++ b/demos/No_Position_Experiment.ipynb @@ -28,6 +28,11 @@ "# Setup" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, { "cell_type": "code", "execution_count": 1, @@ -39,7 +44,7 @@ "\n", " IN_COLAB = True\n", " !pip install einops\n", - " !pip install https://github.com/neelnanda-io/TransformerLens@no-position-experiment\n", + " %pip install transformer_lens\n", "except:\n", " IN_COLAB = False\n", "\n", @@ -577,7 +582,7 @@ } ], "source": [ - "cache[\"blocks.0.attn.hook_attn\"].shape" + "cache[\"blocks.0.attn.hook_pattern\"].shape" ] }, { @@ -717,7 +722,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -733,7 +738,7 @@ "logit_components = (\n", " resid_stack[:, batch_index]\n", " @ fold_W_U\n", - " / cache[\"scale\", None, \"ln_final\"][batch_index]\n", + " / cache[\"scale\"][batch_index]\n", ")\n", "print(logit_components.shape)" ] @@ -1274,7 +1279,7 @@ "losses = []\n", "loss_labels = []\n", "for hook_name in hook_list:\n", - " if hook_name != \"hook_pos_embed\" and \"result\" not in hook_name:\n", + " if hook_name in cache and hook_name != \"hook_pos_embed\" and \"result\" not in hook_name:\n", " average_act = cache[hook_name].mean(0)\n", "\n", " def replacing_with_average_act(activation, hook):\n", From f22a4067026c9a9711b34970d0fc9a03e1c8d848 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 17 Apr 2024 00:27:18 +0200 Subject: [PATCH 58/73] Othello colab fix (#545) * fixed install version and key name * fixed remaining issues with no position experiment * removed extra key * fixed othello in colab --- demos/Othello_GPT.ipynb | 1 + 1 file changed, 1 insertion(+) diff --git a/demos/Othello_GPT.ipynb b/demos/Othello_GPT.ipynb index 7cbb68a5c..1b4400bc7 100644 --- a/demos/Othello_GPT.ipynb +++ b/demos/Othello_GPT.ipynb @@ -69,6 +69,7 @@ " print(\"Running as a 
Colab notebook\")\n", " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n", " %pip install circuitsvis\n", + " %pip install torchtyping\n", " \n", " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", From 65df48bb49488e3bb2e566e747059aa404679094 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 17 Apr 2024 00:59:39 +0200 Subject: [PATCH 59/73] fixed demo for current colab (#546) --- demos/{santacoder.ipynb => Santa_Coder.ipynb} | 1 + 1 file changed, 1 insertion(+) rename demos/{santacoder.ipynb => Santa_Coder.ipynb} (99%) diff --git a/demos/santacoder.ipynb b/demos/Santa_Coder.ipynb similarity index 99% rename from demos/santacoder.ipynb rename to demos/Santa_Coder.ipynb index 61b455035..a69071c38 100644 --- a/demos/santacoder.ipynb +++ b/demos/Santa_Coder.ipynb @@ -35,6 +35,7 @@ " print(\"Running as a Colab notebook\")\n", " %pip install git+https://github.com/neelnanda-io/TransformerLens.git``\n", " %pip install circuitsvis\n", + " %pip install torchtyping\n", " \n", " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", From 0e86253faac86c8bb5eaecad8ea9eb3e91172b63 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 24 Apr 2024 02:37:23 +0200 Subject: [PATCH 60/73] Hf token auth (#550) * added optional token to transfomers loading * added secret for make docs command * ran format * added gated models instructions * rearranged env setting * moved hf token * added temporary log * changed secret reference * changed env variable reference * changed token reference * changed back to secrets reference * removed microsoft models from remote code list * updated token again --- .github/workflows/gh-pages.yml | 4 +++- docs/source/content/getting_started.md | 10 +++++++++ 
transformer_lens/HookedEncoder.py | 7 +++++- transformer_lens/HookedTransformer.py | 3 +++ transformer_lens/loading_from_pretrained.py | 24 ++++++++++++++++----- transformer_lens/utils.py | 3 +++ 6 files changed, 44 insertions(+), 7 deletions(-) diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index a5f6927a7..220dfc093 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -33,7 +33,9 @@ jobs: - name: Install dependencies run: poetry install --with docs - name: Build Docs - run: poetry run build-docs + run: HF_TOKEN="$HF_TOKEN" poetry run build-docs + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} - name: Upload Docs Artifact uses: actions/upload-artifact@v3 with: diff --git a/docs/source/content/getting_started.md b/docs/source/content/getting_started.md index 459b65b44..13952cd5e 100644 --- a/docs/source/content/getting_started.md +++ b/docs/source/content/getting_started.md @@ -19,3 +19,13 @@ One significant design decision made was to have a single transformer implementa Import the library with `import transformer_lens` (Note: This library used to be known as EasyTransformer, and some breaking changes have been made since the rename. If you need to use the old version with some legacy code, run `pip install git+https://github.com/neelnanda-io/TransformerLens@v1`.) + +## Huggingface Gated Access + +Some of the models available in TransformerLens require gated access to be used. Luckily TransformerLens provides a way to access those models via the configuration of an environmental variable. Simply configure your access token found [here](https://huggingface.co/settings/tokens) as `HF_TOKEN` in your environment. + +You will need to make sure you accept the agreements for any gated models, but once you do, the models will work with TransformerLens without issue. 
If you attempt to use one of these models before you have accepted any related agreements, the console output will be very helpful and point you to the URL where you need to accept an agreement. As of 23/4/24, the current list of gated models supported by TransformerLens is as follows. + +* https://huggingface.co/mistralai/Mixtral-8x7B-v0.1 +* https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 +* https://huggingface.co/mistralai/Mistral-7B-v0.1 diff --git a/transformer_lens/HookedEncoder.py b/transformer_lens/HookedEncoder.py index cc0d4e880..59ede19af 100644 --- a/transformer_lens/HookedEncoder.py +++ b/transformer_lens/HookedEncoder.py @@ -7,6 +7,7 @@ from __future__ import annotations import logging +import os from typing import Dict, List, Optional, Tuple, Union, cast, overload import torch @@ -52,7 +53,11 @@ def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs): if tokenizer is not None: self.tokenizer = tokenizer elif self.cfg.tokenizer_name is not None: - self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name) + huggingface_token = os.environ.get("HF_TOKEN", None) + self.tokenizer = AutoTokenizer.from_pretrained( + self.cfg.tokenizer_name, + token=huggingface_token, + ) else: self.tokenizer = None diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index dc65c5c10..104f4feae 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -10,6 +10,7 @@ """ import logging +import os from typing import Dict, List, NamedTuple, Optional, Tuple, Union, cast, overload import einops @@ -140,12 +141,14 @@ def __init__( # should be False if "phi" in self.cfg.tokenizer_name.lower(): use_fast = False + huggingface_token = os.environ.get("HF_TOKEN", None) self.set_tokenizer( AutoTokenizer.from_pretrained( self.cfg.tokenizer_name, add_bos_token=True, trust_remote_code=self.cfg.trust_remote_code, use_fast=use_fast, + token=huggingface_token, ), 
default_padding_side=default_padding_side, ) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index d51295780..7a87c64ba 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -5,6 +5,7 @@ import dataclasses import logging +import os import re from typing import Dict, Optional, Union, cast @@ -610,8 +611,6 @@ NEED_REMOTE_CODE_MODELS = ( "bigcode/santacoder", "Qwen/Qwen-", - "microsoft/phi-1", - "microsoft/phi-1_5", "microsoft/phi-2", ) @@ -659,7 +658,12 @@ def convert_hf_model_config(model_name: str, **kwargs): elif "gemma" in official_model_name.lower(): architecture = "GemmaForCausalLM" else: - hf_config = AutoConfig.from_pretrained(official_model_name, **kwargs) + huggingface_token = os.environ.get("HF_TOKEN", None) + hf_config = AutoConfig.from_pretrained( + official_model_name, + token=huggingface_token, + **kwargs, + ) architecture = hf_config.architectures[0] if official_model_name.startswith( ("llama-7b", "meta-llama/Llama-2-7b") @@ -1378,11 +1382,13 @@ def get_pretrained_state_dict( return state_dict else: if cfg.from_checkpoint: + huggingface_token = os.environ.get("HF_TOKEN", None) if official_model_name.startswith("stanford-crfm"): hf_model = AutoModelForCausalLM.from_pretrained( official_model_name, revision=f"checkpoint-{cfg.checkpoint_value}", torch_dtype=dtype, + token=huggingface_token, **kwargs, ) elif official_model_name.startswith("EleutherAI/pythia"): @@ -1390,20 +1396,28 @@ def get_pretrained_state_dict( official_model_name, revision=f"step{cfg.checkpoint_value}", torch_dtype=dtype, + token=huggingface_token, **kwargs, ) else: raise ValueError(f"Checkpoints for model {official_model_name} are not supported") elif hf_model is None: + huggingface_token = os.environ.get("HF_TOKEN", None) if official_model_name in NON_HF_HOSTED_MODEL_NAMES: raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model") elif "bert" in 
official_model_name: hf_model = BertForPreTraining.from_pretrained( - official_model_name, torch_dtype=dtype, **kwargs + official_model_name, + torch_dtype=dtype, + token=huggingface_token, + **kwargs, ) else: hf_model = AutoModelForCausalLM.from_pretrained( - official_model_name, torch_dtype=dtype, **kwargs + official_model_name, + torch_dtype=dtype, + token=huggingface_token, + **kwargs, ) # Load model weights, and fold in layer norm weights diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index 1b33f99fa..fe9f8d128 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -7,6 +7,7 @@ import inspect import json +import os import re import shutil from copy import deepcopy @@ -1120,9 +1121,11 @@ def get_tokenizer_with_bos(tokenizer): if add_bos_token: tokenizer_with_bos = tokenizer else: + huggingface_token = os.environ.get("HF_TOKEN", None) tokenizer_with_bos = AutoTokenizer.from_pretrained( pretrained_model_name_or_path, add_bos_token=True, + token=huggingface_token, **init_kwargs, ) From d8270c8193d5b638563245983b561131d35fc5b8 Mon Sep 17 00:00:00 2001 From: Clement Dumas Date: Wed, 24 Apr 2024 02:37:55 +0200 Subject: [PATCH 61/73] Fixed device being set to cpu:0 instead of cpu (#551) --- transformer_lens/utilities/devices.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py index 40bc0fbd9..c8e5b78b7 100644 --- a/transformer_lens/utilities/devices.py +++ b/transformer_lens/utilities/devices.py @@ -38,6 +38,8 @@ def get_device_for_block_index( if device is None: device = cfg.device device = torch.device(device) + if device.type == "cpu": + return device device_index = (device.index or 0) + (index // layers_per_device) return torch.device(device.type, device_index) From 2092dc911e30a3acffcbed202a03b53584faa16c Mon Sep 17 00:00:00 2001 From: Joel Burget Date: Tue, 23 Apr 2024 17:52:36 -0700 Subject: [PATCH 62/73] Add support for Llama 3 (and 
Llama-2-70b-hf) (#549) * Start work on adding llama. * Remove v2 from arxiv URL. * Remove llama special case (breaks because hf_config is not defined). * Remove TODO. llama-2-70b-hf and Llama 3 models all have n_key_value_heads set so they'll use Grouped-Query Attention. * Add back check for non-hf-hosted models. * Hardcode Llama-3 configs. See discussion on https://github.com/neelnanda-io/TransformerLens/pull/549 for why. --------- Co-authored-by: Bryce Meyer --- transformer_lens/components.py | 2 +- transformer_lens/loading_from_pretrained.py | 46 ++++++++++++++++++++- 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/transformer_lens/components.py b/transformer_lens/components.py index b61d5bc0b..b419540b7 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -961,7 +961,7 @@ def __init__( attn_type: str = "global", layer_id: Union[int, None] = None, ): - """Grouped Query Attention Block - see https://arxiv.org/abs/2305.13245v2 for details. + """Grouped Query Attention Block - see https://arxiv.org/abs/2305.13245 for details. Similar to regular attention, W_Q, W_K, and W_V all have shape [head_index, d_model, d_head] and W_Q has shape [head_index, d_head, d_model]. However, under the hood the key and value weights _W_K and _W_V are stored with shape [n_key_value_heads, d_model, d_head] and are expanded when the corresponding properties' getter is called. 
Similarly, during a forward pass, initially K and V are kept in shapes [batch, pos, n_key_value_heads, d_head] and will only be expanded to shapes [batch, pos, n_heads, d_head] diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 7a87c64ba..08ff472fe 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -121,7 +121,10 @@ "CodeLlama-7b-hf", "CodeLlama-7b-Python-hf", "CodeLlama-7b-Instruct-hf", - # TODO Llama-2-70b-hf requires Grouped-Query Attention, see the paper https://arxiv.org/pdf/2307.09288.pdf + "meta-llama/Meta-Llama-3-8B", + "meta-llama/Meta-Llama-3-8B-Instruct", + "meta-llama/Meta-Llama-3-70B", + "meta-llama/Meta-Llama-3-70B-Instruct", "Baidicoot/Othello-GPT-Transformer-Lens", "bert-base-cased", "roneneldan/TinyStories-1M", @@ -601,7 +604,7 @@ "llama-30b-hf", "llama-65b-hf", ] -"""Official model names for models that not hosted on HuggingFace.""" +"""Official model names for models not hosted on HuggingFace.""" # Sets a default model alias, by convention the first one in the model alias table, else the official name if it has no aliases DEFAULT_MODEL_ALIASES = [ @@ -665,6 +668,7 @@ def convert_hf_model_config(model_name: str, **kwargs): **kwargs, ) architecture = hf_config.architectures[0] + if official_model_name.startswith( ("llama-7b", "meta-llama/Llama-2-7b") ): # same architecture for LLaMA and Llama-2 @@ -781,6 +785,44 @@ def convert_hf_model_config(model_name: str, **kwargs): "final_rms": True, "gated_mlp": True, } + elif "Meta-Llama-3-8B" in official_model_name: + cfg_dict = { + "d_model": 4096, + "d_head": 128, + "n_heads": 32, + "d_mlp": 14336, + "n_layers": 32, + "n_ctx": 8192, + "eps": 1e-5, + "d_vocab": 128256, + "act_fn": "silu", + "n_key_value_heads": 8, + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "rotary_adjacent_pairs": False, + "rotary_dim": 128, + "final_rms": True, + "gated_mlp": True, + } + elif 
"Meta-Llama-3-70B" in official_model_name: + cfg_dict = { + "d_model": 8192, + "d_head": 128, + "n_heads": 64, + "d_mlp": 28672, + "n_layers": 80, + "n_ctx": 8192, + "eps": 1e-5, + "d_vocab": 128256, + "act_fn": "silu", + "n_key_value_heads": 8, + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "rotary_adjacent_pairs": False, + "rotary_dim": 128, + "final_rms": True, + "gated_mlp": True, + } elif architecture == "GPTNeoForCausalLM": cfg_dict = { "d_model": hf_config.hidden_size, From fe89b042f28ccb7fd07a0f74e0bc00d9f59f5f0d Mon Sep 17 00:00:00 2001 From: Sergii Kharagorgiev Date: Wed, 24 Apr 2024 04:09:38 +0300 Subject: [PATCH 63/73] Loading of huggingface 4-bit quantized Llama (#486) * working demo of 4bit quantized Llama * add memory info to the demo * cleanup, asserts for quantization * hooks reading/writing * test in colab; do not import Int8Params * add some comments * format; fix optional argument use * merge with main * format * ran format * locked attribution patching to 1.1.1 * fixed demo for current colab * minor typing fixes for mypy * fixing typing issue * removing extra W_Q W_O * ignored merge artifacts & push for proper CI run --------- Co-authored-by: Bryce Meyer Co-authored-by: hannamw --- .gitignore | 1 + demos/LLaMA2_GPU_quantized.ipynb | 4806 +++++++++++++++++++ transformer_lens/HookedTransformer.py | 30 +- transformer_lens/HookedTransformerConfig.py | 3 + transformer_lens/components.py | 269 +- transformer_lens/loading_from_pretrained.py | 33 +- 6 files changed, 5066 insertions(+), 76 deletions(-) create mode 100644 demos/LLaMA2_GPU_quantized.ipynb diff --git a/.gitignore b/.gitignore index 131658efd..61589404d 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ docs/build .Ds_Store .pylintrc docs/source/generated +**.orig diff --git a/demos/LLaMA2_GPU_quantized.ipynb b/demos/LLaMA2_GPU_quantized.ipynb new file mode 100644 index 000000000..58631a21e --- /dev/null +++ b/demos/LLaMA2_GPU_quantized.ipynb @@ -0,0 
+1,4806 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "EyASOtpeCUsO" + }, + "source": [ + "# LLaMA and Llama-2 in TransformerLens" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QnUOM0-RCUsO" + }, + "source": [ + "## Setup (skip)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HssVtL08CUsP", + "outputId": "5ad91c32-95e8-4970-99ec-242f9e2ebab2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (0.1.99)\n" + ] + } + ], + "source": [ + "%pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8\n", + "%pip install sentencepiece # Llama tokenizer requires sentencepiece" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "DawCWbiaCUsR", + "outputId": "3f527879-cbd3-42b5-8e72-ba70dc906d79" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Colab notebook\n", + "Collecting git+https://github.com/coolvision/TransformerLens.git@llama_4bit_v2\n", + " Cloning https://github.com/coolvision/TransformerLens.git (to revision llama_4bit_v2) to /tmp/pip-req-build-lpt2rmoh\n", + " Running command git clone --filter=blob:none --quiet https://github.com/coolvision/TransformerLens.git /tmp/pip-req-build-lpt2rmoh\n", + " Running command git checkout -b llama_4bit_v2 --track origin/llama_4bit_v2\n", + " Switched to a new branch 'llama_4bit_v2'\n", + " Branch 'llama_4bit_v2' set up to track remote branch 'llama_4bit_v2' from 'origin'.\n", + " Resolved https://github.com/coolvision/TransformerLens.git to commit b2b80cb92f4aa6d63a456196f0c3472b3d34c6eb\n", + " Installing build dependencies ... 
\u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: accelerate>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.26.1)\n", + "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.14.1)\n", + "Requirement already satisfied: datasets>=2.7.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (2.16.1)\n", + "Requirement already satisfied: einops>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.7.0)\n", + "Requirement already satisfied: fancy-einsum>=0.0.3 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.0.3)\n", + "Requirement already satisfied: jaxtyping>=0.2.11 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.2.25)\n", + "Requirement already satisfied: numpy>=1.24 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (1.26.3)\n", + "Requirement already satisfied: pandas>=1.1.5 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (1.5.3)\n", + "Requirement already satisfied: rich>=12.6.0 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (13.7.0)\n", + "Requirement already satisfied: torch!=2.0,!=2.1.0,>=1.10 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (2.1.2)\n", + "Requirement already satisfied: tqdm>=4.64.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (4.66.1)\n", + "Requirement already satisfied: transformers>=4.25.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (4.35.2)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (4.5.0)\n", + "Requirement 
already satisfied: wandb>=0.13.5 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.16.2)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.23.0->transformer-lens==0.0.0) (23.2)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.23.0->transformer-lens==0.0.0) (5.9.5)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.23.0->transformer-lens==0.0.0) (6.0.1)\n", + "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.23.0->transformer-lens==0.0.0) (0.20.2)\n", + "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.23.0->transformer-lens==0.0.0) (0.4.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (3.13.1)\n", + "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (10.0.1)\n", + "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (0.6)\n", + "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (0.3.7)\n", + "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (2.31.0)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (0.70.15)\n", + "Requirement already satisfied: 
fsspec[http]<=2023.10.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (2023.6.0)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (3.9.1)\n", + "Requirement already satisfied: typeguard<3,>=2.13.3 in /usr/local/lib/python3.10/dist-packages (from jaxtyping>=0.2.11->transformer-lens==0.0.0) (2.13.3)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.1.5->transformer-lens==0.0.0) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.1.5->transformer-lens==0.0.0) (2023.3.post1)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=12.6.0->transformer-lens==0.0.0) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=12.6.0->transformer-lens==0.0.0) (2.16.1)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (3.2.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (3.1.3)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in 
/usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (2.18.1)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (12.1.105)\n", + "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (2.1.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (12.3.101)\n", + "Requirement already satisfied: regex!=2019.12.17 in 
/usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->transformer-lens==0.0.0) (2023.6.3)\n", + "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->transformer-lens==0.0.0) (0.15.0)\n", + "Requirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (8.1.7)\n", + "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (3.1.41)\n", + "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (1.39.2)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (0.4.0)\n", + "Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (1.3.3)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (67.7.2)\n", + "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (1.4.4)\n", + "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (3.20.3)\n", + "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb>=0.13.5->transformer-lens==0.0.0) (1.16.0)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (23.2.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) 
(6.0.4)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (1.9.4)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (1.4.1)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (1.3.1)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (4.0.3)\n", + "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens==0.0.0) (4.0.11)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=12.6.0->transformer-lens==0.0.0) (0.1.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer-lens==0.0.0) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer-lens==0.0.0) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer-lens==0.0.0) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer-lens==0.0.0) (2023.11.17)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (2.1.3)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from 
sympy->torch!=2.0,!=2.1.0,>=1.10->transformer-lens==0.0.0) (1.3.0)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens==0.0.0) (5.0.1)\n", + "Requirement already satisfied: circuitsvis in /usr/local/lib/python3.10/dist-packages (1.43.2)\n", + "Requirement already satisfied: importlib-metadata>=5.1.0 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (7.0.1)\n", + "Requirement already satisfied: numpy>=1.24 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (1.26.3)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in 
/usr/local/lib/python3.10/dist-packages (from circuitsvis) (2.18.1)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: torch>=1.10 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (2.1.2)\n", + "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from circuitsvis) (2.1.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->circuitsvis) (12.3.101)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from triton==2.1.0->circuitsvis) (3.13.1)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.10/dist-packages (from importlib-metadata>=5.1.0->circuitsvis) (3.17.0)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->circuitsvis) (4.5.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->circuitsvis) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->circuitsvis) (3.2.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->circuitsvis) (3.1.3)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->circuitsvis) (2023.6.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10->circuitsvis) (2.1.3)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10->circuitsvis) (1.3.0)\n" + ] + } + ], + "source": [ + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "DEVELOPMENT_MODE = False\n", + "IN_VSCODE = 
False\n", + "try:\n", + " import google.colab\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + " # %pip install git+https://github.com/neelnanda-io/TransformerLens.git``\n", + " %pip install git+https://github.com/coolvision/TransformerLens.git@llama_4bit_v2``\n", + " %pip install circuitsvis\n", + "\n", + " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", + " # # Install another version of node that makes PySvelte work way faster\n", + " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", + " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", + "except:\n", + " IN_COLAB = False\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + " from IPython import get_ipython\n", + "\n", + " ipython = get_ipython()\n", + " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", + " ipython.magic(\"load_ext autoreload\")\n", + " ipython.magic(\"autoreload 2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_OukKSMfCUsR", + "outputId": "27a2a59f-e635-4b80-c759-f00d542352bd" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using renderer: colab\n" + ] + } + ], + "source": [ + "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", + "import plotly.io as pio\n", + "if IN_COLAB or not DEVELOPMENT_MODE:\n", + " pio.renderers.default = \"colab\"\n", + "elif IN_VSCODE:\n", + " pio.renderers.default = \"notebook_connected\"\n", + "print(f\"Using renderer: {pio.renderers.default}\")\n", + "\n", + "import circuitsvis as cv" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "P8zS3MPkCUsR" + }, + "outputs": [], + "source": [ + "# Import stuff\n", + "import torch\n", + 
"import tqdm.auto as tqdm\n", + "import plotly.express as px\n", + "\n", + "from transformers import LlamaForCausalLM, LlamaTokenizer\n", + "from tqdm import tqdm\n", + "from jaxtyping import Float\n", + "\n", + "import transformer_lens\n", + "import transformer_lens.utils as utils\n", + "from transformer_lens.hook_points import (\n", + " HookPoint,\n", + ") # Hooking utilities\n", + "from transformer_lens import HookedTransformer\n", + "\n", + "torch.set_grad_enabled(False)\n", + "\n", + "def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n", + " px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n", + "\n", + "def line(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n", + " px.line(utils.to_numpy(tensor), labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n", + "\n", + "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", renderer=None, **kwargs):\n", + " x = utils.to_numpy(x)\n", + " y = utils.to_numpy(y)\n", + " px.scatter(y=y, x=x, labels={\"x\":xaxis, \"y\":yaxis, \"color\":caxis}, **kwargs).show(renderer)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iXCfIfKKCUsS", + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## Loading LLaMA" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QH3kyhzFCUsS" + }, + "source": [ + "LLaMA weights are not available on HuggingFace, so you'll need to download and convert them\n", + "manually:\n", + "\n", + "1. Get LLaMA weights here: https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform\n", + "\n", + "2. 
Convert the official weights to huggingface:\n", + "\n", + "```bash\n", + "python src/transformers/models/llama/convert_llama_weights_to_hf.py \\\n", + " --input_dir /path/to/downloaded/llama/weights \\\n", + " --model_size 7B \\\n", + " --output_dir /llama/weights/directory/\n", + "```\n", + "\n", + "Note: this didn't work for Arthur by default (even though HF doesn't seem to show this anywhere). I\n", + "had to change this\n", + "line of my pip installed `src/transformers/models/llama/convert_llama_weights_to_hf.py` file (which\n", + "was found at\n", + "`/opt/conda/envs/arthurenv/lib/python3.10/site-packages/transformers/models/llama/convert_llama_weights_to_hf.py`)\n", + "from `input_base_path=os.path.join(args.input_dir, args.model_size),` to `input_base_path=os.path.join(args.input_dir),`\n", + "\n", + "3. Change the ```MODEL_PATH``` variable in the cell below to where the converted weights are stored." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "RdJ0AuW_CUsS" + }, + "outputs": [], + "source": [ + "# MODEL_PATH=''\n", + "\n", + "# tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)\n", + "# hf_model = LlamaForCausalLM.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True)\n", + "\n", + "# model = HookedTransformer.from_pretrained(\"llama-7b\", hf_model=hf_model, device=\"cpu\", fold_ln=False, center_writing_weights=False, center_unembed=False, tokenizer=tokenizer)\n", + "\n", + "# model = model.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "# model.generate(\"The capital of Germany is\", max_new_tokens=20, temperature=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UmOqXE9wCUsS" + }, + "source": [ + "## Loading LLaMA-2\n", + "LLaMA-2 is hosted on HuggingFace, but gated by login.\n", + "\n", + "Before running the notebook, log in to HuggingFace via the cli on your machine:\n", + "```bash\n", + "transformers-cli login\n", + "```\n", + "This will cache your HuggingFace credentials, and 
enable you to download LLaMA-2." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KH6evHq1GGQi" + }, + "source": [ + "## Install additional dependenceis requred for quantization" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "n26wTL_3GYAO", + "outputId": "da381126-148a-43f8-8506-1990d39317f6" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: bitsandbytes in /usr/local/lib/python3.10/dist-packages (0.42.0)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from bitsandbytes) (1.11.4)\n", + "Requirement already satisfied: numpy<1.28.0,>=1.21.6 in /usr/local/lib/python3.10/dist-packages (from scipy->bitsandbytes) (1.26.3)\n", + "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.26.1)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.26.3)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (23.2)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0.1)\n", + "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.1.2)\n", + "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.20.2)\n", + "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.4.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.13.1)\n", + "Requirement already satisfied: typing-extensions in 
/usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (4.5.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.2.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.3)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2023.6.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from 
torch>=1.10.0->accelerate) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.18.1)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (12.1.105)\n", + "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.1.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.10.0->accelerate) (12.3.101)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (2.31.0)\n", + "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (4.66.1)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2023.11.17)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n" + ] + } + ], + "source": [ + "%pip install bitsandbytes\n", + "%pip install accelerate" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": 
"iOntzU3lGZA6" + }, + "source": [ + "## Load quantized model" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 406, + "referenced_widgets": [ + "27ae531ccf1848b79a636a731bc635b9", + "cc67c0d024c848898c5298691db43ed6", + "0e3321d5dcc74ce09baeadce8bd5d6e1", + "4b3dc4ad2d2f4ac59612b6392a6a77f7", + "cfb7c854917e4e3abca3158c0554c081", + "fc7ac85b780a4605809bf68f006e1ef3", + "2582ee70618844ff8168b8637ddbbae4", + "e6b8c1942a54415f998e1f44e796ce1b", + "2c9f2db0723d49fbb902d658703cdcb6", + "37618acd36d44f68b0dbe07cedd60111", + "a31c3597f0a546b0b47bbb873b7c3848", + "1b918c0947a54ae89af225ec372351bf", + "49502d07f7b143e784f4bf76ed6cc272", + "9a3264f31b044a73be31ca4fba3e7f1d", + "0eaef657be2742b1867943854370d1ed", + "e31ca54d82f64114ab694ce50ba62987", + "96e00b8ba2e640d79cd80a1c401dcedd", + "537df88aa48b4ae883ec9f1b86a82d11", + "55f5712956834686ba2b19a36438b726", + "f4a21799b8f84e4687e36314cfc873f0", + "143ecc2bee0c42489185fbbe8c785578", + "fb16992539ed404383974588a5c9263a", + "4e17ea3af1b94b52944a816aba2525ba", + "186ff73f09504161b24be544f4707858", + "24670f0283144c7f97d2ba4733bda25d", + "70a8f83e1235458ab87e69a1c05780a4", + "c09f3c527dce4bed8f6d72922218317e", + "e850e6dcd47f4da2842e5e2b8c8a43c5", + "744091704d0347d6ad0a6032ee362257", + "843f9fad9fea4db88e4201c4aad858b9", + "c0173c0a1f334f6fb3c1e8e0163da260", + "f733175715f640668ae119c03542fcb5", + "09d337ea6f0945c79bad129283c85f0b", + "aec34d7e6c4444778daa42043fb595f2", + "99e864d7de4d461fbb8ccc3f4f765645", + "368b7e51e67749cc8952784c042e3410", + "4ff270b1c5dd484996ce971ed7014317", + "5f9f9e24d6834de3a058a63fecd3a19f", + "6c8c9155c1f04f0a97cd996e741446c0", + "fb3f514d931e4c8192f7cd3557ece592", + "a8d929695aee466fbdc0324dab55aa15", + "d27922254bcb4c30aebff5da1ee34401", + "838dfdb7cec44c86bf41608c17f9852b", + "1720ff23ff344e498c6f80a98bbd2534", + "0f43034866264e019a364912d63ec11e", + "7f0d44e575034332bc737cf702a79468", + 
"201941842a4044148b6fc36e18fe3187", + "0b83c8135ef44951964738be07783832", + "d63f72800b854c7c980ed08f77dd61bf", + "18f4b44f74ff4b3c9198a1ecc8d672a8", + "1d2bfdbba1dd4e3db05a1a0ce052fad7", + "236fbd6f27c14dcba6ec274fe34852ca", + "96bc194611d2414c979340ad2b6416bf", + "2255bd5f2c6841bd9af4bc8963945936", + "687a8ab9cfcc45619d453fef2c8d04ea", + "739dde14dc7e4e46aff9ee8858f0f334", + "6624e2d4d0a04e858c08a709f6bcf31d", + "fdb572e99dde402199a43ac369891538", + "cf227ad8761f4f80a7884ec915add0f4", + "3951d93995c345f3a001ec7fb9a63df1", + "bbad5c1e5c4d46a49e500d5b3890ad0b", + "cca945b5469b41549da4b2bd369c83ff", + "2b903ea4e2934e4ca655296945bea0cd", + "bef5c99c8ba3495b85454d8f10ff17bd", + "0604d0239806455889c09137ebee2815", + "9a3e9cc371da4801ab642ea09101a40f", + "b7daa28812e54fd6941f33d1b8325666", + "f660c1ba3a664866a5abcdd3c35d0e72", + "707287071f58451b8e5da6688cba286b", + "21f81e0c00e94b5d82039aca95e44bfb", + "71a49a876cc1413fad925c15672d5919", + "1174909ea64c4afa9ba244564eeafaf5", + "a76b6b9e474a4cf79c98b52758adeab5", + "4abc540a7c404642b7a988937efdd196", + "d98b1abde30a4dd28468c7c12d76c822", + "ce270b54645e4f1d87a0a10416739c1f", + "056409f93e7f48f9a9b8556ed27fcd08", + "d5bb62948e5545319030671e17be5ced", + "e354464414fd479388da58d033700024", + "07b6cec0de3040efa472867d07cd0495", + "2292ea67682e44bc8c61ce31bd18371c", + "a6182b17067c431fb933d36cef4d4438", + "eb69099d7fd9416881bdc9635838da1b", + "6f9ff8cbf6b3427ab3d30aa19b7703db", + "162b1719fad54fcabe0c9b0646956985", + "e2a836342061410c86827b364d3feb29", + "9d350966e16c409eb32e0642ee908f24", + "cce9b0e3791e442a9e7b9b87a3ffec41", + "b525a8f8d07b4abcb162fcca0bfab28d", + "e86dea79f86143cab0b88c2d5f8992e8", + "9d0f7c1505c9436ca8c04724381dd70f", + "dd5b7da397c849adabcf33c2b7c3aeb8", + "1be40306cb6844319d44ba1f4f0164cb", + "7b2d516d0c5046dc816801790aa4c2b5", + "9e1c77e44d884d4d80684179c6f4c96c", + "bb637c634d6f4efebafa3aa503c9862e", + "8b28096911874bae9798969414e385de", + "1d118df455dd4d6aa395df422c711573", + 
"e990b322c8104b329227207e442003ee", + "22c5623f724647689c50a0e7c37cb371", + "fceb5194627442d79ee5ecd91ee01a10", + "76516b2b4e104548a4b892f3979f41c8", + "54ffa55dfd154b8a83edf45c157bec11", + "042a11af13fe4a1ab7c2a3beef4ee1ff", + "51664a883bcb4d52b55e8826f17726bb", + "951f41fc154c4f8aa38124c11e481223", + "96f61db2adf54c45bf60630919a67c95", + "a055f35c82bb43cbb2ff3f8a98fa7a90", + "da72027fa75c4687bee46ab832297882", + "251d31dade3348acaec1eb247ebb33ab", + "d6603805bb654ea9a7e0ae2b60dd5be6", + "76e5f131438d4086bfcf52f7f4969b3a", + "19f9d3421cef4c38abb500910e251f22", + "c3520f624fd043bfaac26503ff10f254", + "f4e419873d6f49a99b8718666811073e", + "6485458203c946bba61f5ec96959abfb", + "c4bcdc27cad7466e8bf0f8655373a71b", + "c92a8c987de548239ec78c93ee5bf660", + "1dc24cb487dc4b919bbd3c7e6a89e150", + "1c861f3ec5214c148a714f04bdd9696c", + "92f27d77468c407aae1208893996a2a7" + ] + }, + "id": "urpZu9jECUsT", + "outputId": "d4e9217b-a099-4148-89ba-cd790dcde7e5" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "27ae531ccf1848b79a636a731bc635b9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "model.safetensors.index.json: 0%| | 0.00/26.8k [00:00\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "llama_text = \"Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets.\"\n", + "llama_tokens = model.to_tokens(llama_text)\n", + "llama_logits, llama_cache = model.run_with_cache(llama_tokens, remove_batch_dim=True)\n", + "\n", + "attention_pattern = llama_cache[\"pattern\", 0, \"attn\"]\n", + "llama_str_tokens = model.to_str_tokens(llama_text)\n", + "\n", + "print(\"Layer 0 Head Attention Patterns:\")\n", + "display(cv.attention.attention_patterns(tokens=llama_str_tokens, attention=attention_pattern))" + ] + 
}, + { + "cell_type": "markdown", + "metadata": { + "id": "0f_t5QZeCUsU" + }, + "source": [ + "### Writing to hooks" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YVHW-VKJCUsU", + "outputId": "346b37b7-0bfd-4dcf-e21c-fa930dec14c1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shape of the value tensor: torch.Size([1, 34, 32, 128])\n", + "Original Loss: 2.841\n", + "Ablated Loss: 2.806\n" + ] + } + ], + "source": [ + "layer_to_ablate = 0\n", + "head_index_to_ablate = 31\n", + "\n", + "# We define a head ablation hook\n", + "# The type annotations are NOT necessary, they're just a useful guide to the reader\n", + "#\n", + "def head_ablation_hook(\n", + " value: Float[torch.Tensor, \"batch pos head_index d_head\"],\n", + " hook: HookPoint\n", + ") -> Float[torch.Tensor, \"batch pos head_index d_head\"]:\n", + " print(f\"Shape of the value tensor: {value.shape}\")\n", + " value[:, :, head_index_to_ablate, :] = 0.\n", + " return value\n", + "\n", + "original_loss = model(llama_tokens, return_type=\"loss\")\n", + "ablated_loss = model.run_with_hooks(\n", + " llama_tokens,\n", + " return_type=\"loss\",\n", + " fwd_hooks=[(\n", + " utils.get_act_name(\"v\", layer_to_ablate),\n", + " head_ablation_hook\n", + " )]\n", + " )\n", + "print(f\"Original Loss: {original_loss.item():.3f}\")\n", + "print(f\"Ablated Loss: {ablated_loss.item():.3f}\")" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + }, + "vscode": { + "interpreter": { 
+ "hash": "f03ec946e3b5caa7cc710a963f479e62a68fff56c790a7066e03c8b5c22adad9" + } + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "042a11af13fe4a1ab7c2a3beef4ee1ff": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "056409f93e7f48f9a9b8556ed27fcd08": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "0604d0239806455889c09137ebee2815": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": 
"LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "07b6cec0de3040efa472867d07cd0495": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_162b1719fad54fcabe0c9b0646956985", + "max": 499723, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_e2a836342061410c86827b364d3feb29", + "value": 499723 + } + }, + "09d337ea6f0945c79bad129283c85f0b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": 
"DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "0b83c8135ef44951964738be07783832": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2255bd5f2c6841bd9af4bc8963945936", + "placeholder": "​", + "style": "IPY_MODEL_687a8ab9cfcc45619d453fef2c8d04ea", + "value": " 2/2 [01:25<00:00, 38.63s/it]" + } + }, + "0e3321d5dcc74ce09baeadce8bd5d6e1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e6b8c1942a54415f998e1f44e796ce1b", + "max": 26788, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_2c9f2db0723d49fbb902d658703cdcb6", + "value": 26788 + } + }, + "0eaef657be2742b1867943854370d1ed": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": 
"1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_143ecc2bee0c42489185fbbe8c785578", + "placeholder": "​", + "style": "IPY_MODEL_fb16992539ed404383974588a5c9263a", + "value": " 2/2 [02:09<00:00, 59.74s/it]" + } + }, + "0f43034866264e019a364912d63ec11e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7f0d44e575034332bc737cf702a79468", + "IPY_MODEL_201941842a4044148b6fc36e18fe3187", + "IPY_MODEL_0b83c8135ef44951964738be07783832" + ], + "layout": "IPY_MODEL_d63f72800b854c7c980ed08f77dd61bf" + } + }, + "1174909ea64c4afa9ba244564eeafaf5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + 
"left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "143ecc2bee0c42489185fbbe8c785578": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "162b1719fad54fcabe0c9b0646956985": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + 
"align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1720ff23ff344e498c6f80a98bbd2534": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "186ff73f09504161b24be544f4707858": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e850e6dcd47f4da2842e5e2b8c8a43c5", + "placeholder": "​", + "style": "IPY_MODEL_744091704d0347d6ad0a6032ee362257", + "value": "model-00001-of-00002.safetensors: 100%" + } + }, + 
"18f4b44f74ff4b3c9198a1ecc8d672a8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "19f9d3421cef4c38abb500910e251f22": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c92a8c987de548239ec78c93ee5bf660", + "max": 2, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_1dc24cb487dc4b919bbd3c7e6a89e150", + "value": 2 + } + }, + 
"1b918c0947a54ae89af225ec372351bf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_49502d07f7b143e784f4bf76ed6cc272", + "IPY_MODEL_9a3264f31b044a73be31ca4fba3e7f1d", + "IPY_MODEL_0eaef657be2742b1867943854370d1ed" + ], + "layout": "IPY_MODEL_e31ca54d82f64114ab694ce50ba62987" + } + }, + "1be40306cb6844319d44ba1f4f0164cb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1c861f3ec5214c148a714f04bdd9696c": { + "model_module": 
"@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1d118df455dd4d6aa395df422c711573": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + 
"grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1d2bfdbba1dd4e3db05a1a0ce052fad7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "1dc24cb487dc4b919bbd3c7e6a89e150": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "201941842a4044148b6fc36e18fe3187": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_236fbd6f27c14dcba6ec274fe34852ca", + "max": 2, + 
"min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_96bc194611d2414c979340ad2b6416bf", + "value": 2 + } + }, + "21f81e0c00e94b5d82039aca95e44bfb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ce270b54645e4f1d87a0a10416739c1f", + "placeholder": "​", + "style": "IPY_MODEL_056409f93e7f48f9a9b8556ed27fcd08", + "value": " 1.62k/1.62k [00:00<00:00, 98.1kB/s]" + } + }, + "2255bd5f2c6841bd9af4bc8963945936": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + 
"visibility": null, + "width": null + } + }, + "2292ea67682e44bc8c61ce31bd18371c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9d350966e16c409eb32e0642ee908f24", + "placeholder": "​", + "style": "IPY_MODEL_cce9b0e3791e442a9e7b9b87a3ffec41", + "value": " 500k/500k [00:00<00:00, 31.8MB/s]" + } + }, + "22c5623f724647689c50a0e7c37cb371": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_fceb5194627442d79ee5ecd91ee01a10", + "IPY_MODEL_76516b2b4e104548a4b892f3979f41c8", + "IPY_MODEL_54ffa55dfd154b8a83edf45c157bec11" + ], + "layout": "IPY_MODEL_042a11af13fe4a1ab7c2a3beef4ee1ff" + } + }, + "236fbd6f27c14dcba6ec274fe34852ca": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + 
"grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "24670f0283144c7f97d2ba4733bda25d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_843f9fad9fea4db88e4201c4aad858b9", + "max": 9976576152, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_c0173c0a1f334f6fb3c1e8e0163da260", + "value": 9976576152 + } + }, + "251d31dade3348acaec1eb247ebb33ab": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "2582ee70618844ff8168b8637ddbbae4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + 
"model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "27ae531ccf1848b79a636a731bc635b9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_cc67c0d024c848898c5298691db43ed6", + "IPY_MODEL_0e3321d5dcc74ce09baeadce8bd5d6e1", + "IPY_MODEL_4b3dc4ad2d2f4ac59612b6392a6a77f7" + ], + "layout": "IPY_MODEL_cfb7c854917e4e3abca3158c0554c081" + } + }, + "2b903ea4e2934e4ca655296945bea0cd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + 
"min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2c9f2db0723d49fbb902d658703cdcb6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "368b7e51e67749cc8952784c042e3410": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a8d929695aee466fbdc0324dab55aa15", + "max": 3500296424, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d27922254bcb4c30aebff5da1ee34401", + "value": 3500296424 + } + }, + "37618acd36d44f68b0dbe07cedd60111": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": 
null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3951d93995c345f3a001ec7fb9a63df1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + 
"49502d07f7b143e784f4bf76ed6cc272": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_96e00b8ba2e640d79cd80a1c401dcedd", + "placeholder": "​", + "style": "IPY_MODEL_537df88aa48b4ae883ec9f1b86a82d11", + "value": "Downloading shards: 100%" + } + }, + "4abc540a7c404642b7a988937efdd196": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4b3dc4ad2d2f4ac59612b6392a6a77f7": { + "model_module": "@jupyter-widgets/controls", + 
"model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_37618acd36d44f68b0dbe07cedd60111", + "placeholder": "​", + "style": "IPY_MODEL_a31c3597f0a546b0b47bbb873b7c3848", + "value": " 26.8k/26.8k [00:00<00:00, 1.55MB/s]" + } + }, + "4e17ea3af1b94b52944a816aba2525ba": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_186ff73f09504161b24be544f4707858", + "IPY_MODEL_24670f0283144c7f97d2ba4733bda25d", + "IPY_MODEL_70a8f83e1235458ab87e69a1c05780a4" + ], + "layout": "IPY_MODEL_c09f3c527dce4bed8f6d72922218317e" + } + }, + "4ff270b1c5dd484996ce971ed7014317": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_838dfdb7cec44c86bf41608c17f9852b", + "placeholder": "​", + "style": "IPY_MODEL_1720ff23ff344e498c6f80a98bbd2534", + "value": " 3.50G/3.50G [00:36<00:00, 132MB/s]" + } + }, + "51664a883bcb4d52b55e8826f17726bb": { + 
"model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "537df88aa48b4ae883ec9f1b86a82d11": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "54ffa55dfd154b8a83edf45c157bec11": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + 
"_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_da72027fa75c4687bee46ab832297882", + "placeholder": "​", + "style": "IPY_MODEL_251d31dade3348acaec1eb247ebb33ab", + "value": " 414/414 [00:00<00:00, 30.9kB/s]" + } + }, + "55f5712956834686ba2b19a36438b726": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5f9f9e24d6834de3a058a63fecd3a19f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": 
"LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6485458203c946bba61f5ec96959abfb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, 
+ "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6624e2d4d0a04e858c08a709f6bcf31d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bbad5c1e5c4d46a49e500d5b3890ad0b", + "placeholder": "​", + "style": "IPY_MODEL_cca945b5469b41549da4b2bd369c83ff", + "value": "generation_config.json: 100%" + } + }, + "687a8ab9cfcc45619d453fef2c8d04ea": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6c8c9155c1f04f0a97cd996e741446c0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": 
null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6f9ff8cbf6b3427ab3d30aa19b7703db": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "707287071f58451b8e5da6688cba286b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4abc540a7c404642b7a988937efdd196", + "max": 1618, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_d98b1abde30a4dd28468c7c12d76c822", + "value": 1618 + } + }, + "70a8f83e1235458ab87e69a1c05780a4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + 
"_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f733175715f640668ae119c03542fcb5", + "placeholder": "​", + "style": "IPY_MODEL_09d337ea6f0945c79bad129283c85f0b", + "value": " 9.98G/9.98G [01:33<00:00, 167MB/s]" + } + }, + "71a49a876cc1413fad925c15672d5919": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "739dde14dc7e4e46aff9ee8858f0f334": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + 
"_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_6624e2d4d0a04e858c08a709f6bcf31d", + "IPY_MODEL_fdb572e99dde402199a43ac369891538", + "IPY_MODEL_cf227ad8761f4f80a7884ec915add0f4" + ], + "layout": "IPY_MODEL_3951d93995c345f3a001ec7fb9a63df1" + } + }, + "744091704d0347d6ad0a6032ee362257": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "76516b2b4e104548a4b892f3979f41c8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_96f61db2adf54c45bf60630919a67c95", + "max": 414, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_a055f35c82bb43cbb2ff3f8a98fa7a90", + "value": 414 + } + }, + "76e5f131438d4086bfcf52f7f4969b3a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + 
"layout": "IPY_MODEL_6485458203c946bba61f5ec96959abfb", + "placeholder": "​", + "style": "IPY_MODEL_c4bcdc27cad7466e8bf0f8655373a71b", + "value": "100%" + } + }, + "7b2d516d0c5046dc816801790aa4c2b5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7f0d44e575034332bc737cf702a79468": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_18f4b44f74ff4b3c9198a1ecc8d672a8", + "placeholder": "​", + "style": 
"IPY_MODEL_1d2bfdbba1dd4e3db05a1a0ce052fad7", + "value": "Loading checkpoint shards: 100%" + } + }, + "838dfdb7cec44c86bf41608c17f9852b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "843f9fad9fea4db88e4201c4aad858b9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + 
"grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8b28096911874bae9798969414e385de": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "92f27d77468c407aae1208893996a2a7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "951f41fc154c4f8aa38124c11e481223": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + 
"description_width": "" + } + }, + "96bc194611d2414c979340ad2b6416bf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "96e00b8ba2e640d79cd80a1c401dcedd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "96f61db2adf54c45bf60630919a67c95": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": 
"1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "99e864d7de4d461fbb8ccc3f4f765645": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_6c8c9155c1f04f0a97cd996e741446c0", + "placeholder": "​", + "style": "IPY_MODEL_fb3f514d931e4c8192f7cd3557ece592", + "value": "model-00002-of-00002.safetensors: 100%" + } + }, + "9a3264f31b044a73be31ca4fba3e7f1d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": 
"FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_55f5712956834686ba2b19a36438b726", + "max": 2, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f4a21799b8f84e4687e36314cfc873f0", + "value": 2 + } + }, + "9a3e9cc371da4801ab642ea09101a40f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9d0f7c1505c9436ca8c04724381dd70f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bb637c634d6f4efebafa3aa503c9862e", + "max": 1842767, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_8b28096911874bae9798969414e385de", + "value": 1842767 + } + }, + "9d350966e16c409eb32e0642ee908f24": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": 
"1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9e1c77e44d884d4d80684179c6f4c96c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a055f35c82bb43cbb2ff3f8a98fa7a90": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "a31c3597f0a546b0b47bbb873b7c3848": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + 
"state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a6182b17067c431fb933d36cef4d4438": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a76b6b9e474a4cf79c98b52758adeab5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + 
"description_width": "" + } + }, + "a8d929695aee466fbdc0324dab55aa15": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "aec34d7e6c4444778daa42043fb595f2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_99e864d7de4d461fbb8ccc3f4f765645", + "IPY_MODEL_368b7e51e67749cc8952784c042e3410", + "IPY_MODEL_4ff270b1c5dd484996ce971ed7014317" + ], + "layout": "IPY_MODEL_5f9f9e24d6834de3a058a63fecd3a19f" + } + }, + "b525a8f8d07b4abcb162fcca0bfab28d": { 
+ "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e86dea79f86143cab0b88c2d5f8992e8", + "IPY_MODEL_9d0f7c1505c9436ca8c04724381dd70f", + "IPY_MODEL_dd5b7da397c849adabcf33c2b7c3aeb8" + ], + "layout": "IPY_MODEL_1be40306cb6844319d44ba1f4f0164cb" + } + }, + "b7daa28812e54fd6941f33d1b8325666": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_f660c1ba3a664866a5abcdd3c35d0e72", + "IPY_MODEL_707287071f58451b8e5da6688cba286b", + "IPY_MODEL_21f81e0c00e94b5d82039aca95e44bfb" + ], + "layout": "IPY_MODEL_71a49a876cc1413fad925c15672d5919" + } + }, + "bb637c634d6f4efebafa3aa503c9862e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + 
"grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bbad5c1e5c4d46a49e500d5b3890ad0b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bef5c99c8ba3495b85454d8f10ff17bd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + 
"_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "c0173c0a1f334f6fb3c1e8e0163da260": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "c09f3c527dce4bed8f6d72922218317e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + 
"visibility": null, + "width": null + } + }, + "c3520f624fd043bfaac26503ff10f254": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1c861f3ec5214c148a714f04bdd9696c", + "placeholder": "​", + "style": "IPY_MODEL_92f27d77468c407aae1208893996a2a7", + "value": " 2/2 [00:02<00:00, 1.09s/it]" + } + }, + "c4bcdc27cad7466e8bf0f8655373a71b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "c92a8c987de548239ec78c93ee5bf660": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + 
"grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cc67c0d024c848898c5298691db43ed6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_fc7ac85b780a4605809bf68f006e1ef3", + "placeholder": "​", + "style": "IPY_MODEL_2582ee70618844ff8168b8637ddbbae4", + "value": "model.safetensors.index.json: 100%" + } + }, + "cca945b5469b41549da4b2bd369c83ff": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "cce9b0e3791e442a9e7b9b87a3ffec41": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + 
"_view_name": "StyleView", + "description_width": "" + } + }, + "ce270b54645e4f1d87a0a10416739c1f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cf227ad8761f4f80a7884ec915add0f4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0604d0239806455889c09137ebee2815", + "placeholder": "​", + "style": "IPY_MODEL_9a3e9cc371da4801ab642ea09101a40f", + "value": " 188/188 [00:00<00:00, 13.3kB/s]" + } + }, + 
"cfb7c854917e4e3abca3158c0554c081": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d27922254bcb4c30aebff5da1ee34401": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "d5bb62948e5545319030671e17be5ced": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + 
"_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e354464414fd479388da58d033700024", + "IPY_MODEL_07b6cec0de3040efa472867d07cd0495", + "IPY_MODEL_2292ea67682e44bc8c61ce31bd18371c" + ], + "layout": "IPY_MODEL_a6182b17067c431fb933d36cef4d4438" + } + }, + "d63f72800b854c7c980ed08f77dd61bf": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d6603805bb654ea9a7e0ae2b60dd5be6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_76e5f131438d4086bfcf52f7f4969b3a", + "IPY_MODEL_19f9d3421cef4c38abb500910e251f22", + "IPY_MODEL_c3520f624fd043bfaac26503ff10f254" + ], + "layout": "IPY_MODEL_f4e419873d6f49a99b8718666811073e" + } + }, + "d98b1abde30a4dd28468c7c12d76c822": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "da72027fa75c4687bee46ab832297882": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": 
null, + "top": null, + "visibility": null, + "width": null + } + }, + "dd5b7da397c849adabcf33c2b7c3aeb8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1d118df455dd4d6aa395df422c711573", + "placeholder": "​", + "style": "IPY_MODEL_e990b322c8104b329227207e442003ee", + "value": " 1.84M/1.84M [00:00<00:00, 25.3MB/s]" + } + }, + "e2a836342061410c86827b364d3feb29": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "e31ca54d82f64114ab694ce50ba62987": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + 
"grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e354464414fd479388da58d033700024": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_eb69099d7fd9416881bdc9635838da1b", + "placeholder": "​", + "style": "IPY_MODEL_6f9ff8cbf6b3427ab3d30aa19b7703db", + "value": "tokenizer.model: 100%" + } + }, + "e6b8c1942a54415f998e1f44e796ce1b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + 
"justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e850e6dcd47f4da2842e5e2b8c8a43c5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e86dea79f86143cab0b88c2d5f8992e8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7b2d516d0c5046dc816801790aa4c2b5", + "placeholder": "​", + "style": "IPY_MODEL_9e1c77e44d884d4d80684179c6f4c96c", + "value": "tokenizer.json: 100%" + } + }, + "e990b322c8104b329227207e442003ee": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "eb69099d7fd9416881bdc9635838da1b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + 
"visibility": null, + "width": null + } + }, + "f4a21799b8f84e4687e36314cfc873f0": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "f4e419873d6f49a99b8718666811073e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f660c1ba3a664866a5abcdd3c35d0e72": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": 
"@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1174909ea64c4afa9ba244564eeafaf5", + "placeholder": "​", + "style": "IPY_MODEL_a76b6b9e474a4cf79c98b52758adeab5", + "value": "tokenizer_config.json: 100%" + } + }, + "f733175715f640668ae119c03542fcb5": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fb16992539ed404383974588a5c9263a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": 
"DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "fb3f514d931e4c8192f7cd3557ece592": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "fc7ac85b780a4605809bf68f006e1ef3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fceb5194627442d79ee5ecd91ee01a10": { + "model_module": "@jupyter-widgets/controls", + 
"model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_51664a883bcb4d52b55e8826f17726bb", + "placeholder": "​", + "style": "IPY_MODEL_951f41fc154c4f8aa38124c11e481223", + "value": "special_tokens_map.json: 100%" + } + }, + "fdb572e99dde402199a43ac369891538": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2b903ea4e2934e4ca655296945bea0cd", + "max": 188, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_bef5c99c8ba3495b85454d8f10ff17bd", + "value": 188 + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 104f4feae..baf32ad05 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -20,6 +20,7 @@ import tqdm.auto as tqdm from fancy_einsum import einsum from jaxtyping import Float, Int +from packaging import version from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase from typing_extensions import Literal @@ -1188,11 +1189,32 @@ def from_pretrained( default_padding_side: Which side to pad on when tokenizing. Defaults to "right". 
""" + assert not ( from_pretrained_kwargs.get("load_in_8bit", False) or from_pretrained_kwargs.get("load_in_4bit", False) ), "Quantization not supported" + if hf_model is not None: + hf_cfg = hf_model.config.to_dict() + qc = hf_cfg.get("quantization_config", {}) + load_in_4bit = qc.get("load_in_4bit", False) + load_in_8bit = qc.get("load_in_8bit", False) + quant_method = qc.get("quant_method", "") + assert not load_in_8bit, "8-bit quantization is not supported" + assert not ( + load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1")) + ), "Quantization is only supported for torch versions >= 2.1.1" + assert not ( + load_in_4bit and ("llama" not in model_name.lower()) + ), "Quantization is only supported for Llama models" + if load_in_4bit: + assert ( + qc.get("quant_method", "") == "bitsandbytes" + ), "Only bitsandbytes quantization is supported" + else: + hf_cfg = {} + if isinstance(dtype, str): # Convert from string to a torch dtype dtype = DTYPE_FROM_STRING[dtype] @@ -1215,6 +1237,7 @@ def from_pretrained( # checkpoint cfg = loading.get_pretrained_model_config( official_model_name, + hf_cfg=hf_cfg, checkpoint_index=checkpoint_index, checkpoint_value=checkpoint_value, fold_ln=fold_ln, @@ -1519,7 +1542,12 @@ def load_and_process_state_dict( if refactor_factored_attn_matrices: state_dict = self.refactor_factored_attn_matrices(state_dict) - self.load_state_dict(state_dict, strict=False) + if self.cfg.load_in_4bit: + # with quantization, parameters should be assigned + # so that quantization settings are not lost + self.load_state_dict(state_dict, assign=True, strict=False) + else: + self.load_state_dict(state_dict, strict=False) def fill_missing_keys(self, state_dict): return loading.fill_missing_keys(self, state_dict) diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py index d41079bca..7a38c22c5 100644 --- a/transformer_lens/HookedTransformerConfig.py +++ 
b/transformer_lens/HookedTransformerConfig.py @@ -149,6 +149,8 @@ class HookedTransformerConfig: tokenizer_prepends_bos (bool, *optional*): This flag is set by set_tokenizer. It is set to True only when the tokenizer automatically prepends the BOS token if initialized with add_bos_token=True. We need this information to dynamically control bos prepending. + load_in_4bit(bool): If this flag is set, then it's assumed that parameters are 4-bit quantized + with bitsandbytes. Currently only supported for Llama. n_key_value_heads (int, *optional*): The number of groups of heads that use the same key and value matrix. Only for models that use Grouped Query Attention. post_embedding_ln (bool): Whether to apply layer normalization after embedding the tokens. Defaults @@ -209,6 +211,7 @@ class HookedTransformerConfig: rotary_base: int = 10000 trust_remote_code: bool = False rotary_adjacent_pairs: bool = False + load_in_4bit: bool = False num_experts: Optional[int] = None experts_per_token: Optional[int] = None diff --git a/transformer_lens/components.py b/transformer_lens/components.py index b419540b7..66f6ff9d9 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -17,6 +17,7 @@ from better_abc import abstract_attribute from fancy_einsum import einsum from jaxtyping import Float, Int +from transformers.utils import is_bitsandbytes_available from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookPoint @@ -30,6 +31,10 @@ solu, ) +if is_bitsandbytes_available(): + import bitsandbytes as bnb + from bitsandbytes.nn.modules import Params4bit + # Embed & Unembed class Embed(nn.Module): @@ -398,14 +403,21 @@ def __init__( if isinstance(cfg, Dict): cfg = HookedTransformerConfig.from_dict(cfg) self.cfg = cfg - self.W_Q = nn.Parameter( - torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype) - ) + + if self.cfg.load_in_4bit: + nq = int((cfg.d_model * cfg.d_model) / 2) + self.W_Q = 
Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) + self.W_O = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) + else: + self.W_Q = nn.Parameter( + torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype) + ) + self.W_O = nn.Parameter( + torch.empty(self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model, dtype=cfg.dtype) + ) self.W_K = abstract_attribute() self.W_V = abstract_attribute() - self.W_O = nn.Parameter( - torch.empty(self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model, dtype=cfg.dtype) - ) + self.b_Q = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)) self.b_K = abstract_attribute() self.b_V = abstract_attribute() @@ -574,30 +586,51 @@ def forward( pattern = pattern.to(v.device) z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] if not self.cfg.use_attn_result: - out = ( - ( + if self.cfg.load_in_4bit: + # call bitsandbytes method to dequantize and multiply + out = bnb.matmul_4bit( + z.reshape(z.shape[0], z.shape[1], self.cfg.d_model), + self.W_O.t(), + # bias=self.W_O.t(), + bias=None, + quant_state=self.W_O.quant_state, + ) + +self.b_O + else: + out = ( + ( + einsum( + "batch pos head_index d_head, \ + head_index d_head d_model -> \ + batch pos d_model", + z, + self.W_O, + ) + ) + + self.b_O + ) # [batch, pos, d_model] + else: + # Explicitly calculate the attention result so it can be accessed by a hook + # This is off by default because it can easily eat through your GPU memory. 
+ if self.cfg.load_in_4bit: + result = self.hook_result( + bnb.matmul_4bit( + z.reshape(z.shape[0], z.shape[1], self.cfg.d_model), + self.W_O.t(), + bias=None, + quant_state=self.W_O.quant_state, + ) + ) + else: + result = self.hook_result( einsum( "batch pos head_index d_head, \ head_index d_head d_model -> \ - batch pos d_model", + batch pos head_index d_model", z, self.W_O, ) - ) - + self.b_O - ) # [batch, pos, d_model] - else: - # Explicitly calculate the attention result so it can be accessed by a hook - # This is off by default because it can easily eat through your GPU memory. - result = self.hook_result( - einsum( - "batch pos head_index d_head, \ - head_index d_head d_model -> \ - batch pos head_index d_model", - z, - self.W_O, - ) - ) # [batch, pos, head_index, d_model] + ) # [batch, pos, head_index, d_model] out = ( einops.reduce(result, "batch position index model->batch position model", "sum") + self.b_O @@ -628,33 +661,82 @@ def calculate_qkv_matrices( else: qkv_einops_string = "batch pos d_model" - q = self.hook_q( - einsum( - f"{qkv_einops_string}, head_index d_model d_head \ - -> batch pos head_index d_head", - query_input, - self.W_Q, + if self.cfg.load_in_4bit: + q = self.hook_q( + # call bitsandbytes method to dequantize and multiply + bnb.matmul_4bit( + query_input, + self.W_Q.t(), + bias=None, + quant_state=self.W_Q.quant_state, + ).reshape( + query_input.shape[0], + query_input.shape[1], + self.cfg.n_heads, + self.cfg.d_head, + ) + + self.b_Q ) - + self.b_Q - ) # [batch, pos, head_index, d_head] - k = self.hook_k( - einsum( - f"{qkv_einops_string}, head_index d_model d_head \ - -> batch pos head_index d_head", - key_input, - self.W_K, + else: + q = self.hook_q( + einsum( + f"{qkv_einops_string}, head_index d_model d_head \ + -> batch pos head_index d_head", + query_input, + self.W_Q, + ) + + self.b_Q + ) # [batch, pos, head_index, d_head] + if self.cfg.load_in_4bit: + k = self.hook_k( + # call bitsandbytes method to dequantize and multiply + 
bnb.matmul_4bit( + key_input, self.W_K.t(), bias=None, quant_state=self.W_K.quant_state + ).reshape( + key_input.shape[0], + key_input.shape[1], + self.cfg.n_heads, + self.cfg.d_head, + ) + + self.b_K ) - + self.b_K - ) # [batch, pos, head_index, d_head] - v = self.hook_v( - einsum( - f"{qkv_einops_string}, head_index d_model d_head \ - -> batch pos head_index d_head", - value_input, - self.W_V, + else: + k = self.hook_k( + einsum( + f"{qkv_einops_string}, head_index d_model d_head \ + -> batch pos head_index d_head", + key_input, + self.W_K, + ) + + self.b_K + ) # [batch, pos, head_index, d_head] + + if self.cfg.load_in_4bit: + v = self.hook_v( + # call bitsandbytes method to dequantize and multiply + bnb.matmul_4bit( + value_input, + self.W_V.t(), + bias=None, + quant_state=self.W_V.quant_state, + ).reshape( + value_input.shape[0], + value_input.shape[1], + self.cfg.n_heads, + self.cfg.d_head, + ) + + self.b_V ) - + self.b_V - ) # [batch, pos, head_index, d_head] + else: + v = self.hook_v( + einsum( + f"{qkv_einops_string}, head_index d_model d_head \ + -> batch pos head_index d_head", + value_input, + self.W_V, + ) + + self.b_V + ) # [batch, pos, head_index, d_head] return q, k, v def calculate_attention_scores( @@ -944,12 +1026,21 @@ def __init__( if isinstance(cfg, Dict): cfg = HookedTransformerConfig.from_dict(cfg) self.cfg = cfg - self.W_K = nn.Parameter( - torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype) - ) - self.W_V = nn.Parameter( - torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype) - ) + + if cfg.load_in_4bit: + # 4-bit quantization convention + nq = int((cfg.d_model * cfg.d_model) / 2) + self.W_K = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) + self.W_V = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) + else: + self.W_K = nn.Parameter( + torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype) + ) + self.W_V = 
nn.Parameter( + torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head, dtype=cfg.dtype) + ) + self.b_K = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)) + self.b_V = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)) self.b_K = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)) self.b_V = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)) @@ -1220,10 +1311,22 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): cfg = HookedTransformerConfig.from_dict(cfg) self.cfg = cfg assert self.cfg.d_mlp is not None # keep mypy happy - self.W_in = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype)) - self.W_gate = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype)) + + if cfg.load_in_4bit: + nq = int((self.cfg.d_model * self.cfg.d_mlp) / 2) + self.W_in = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) + self.W_gate = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) + self.W_out = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) + else: + self.W_in = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype)) + self.W_gate = nn.Parameter( + torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype) + ) + self.W_out = nn.Parameter( + torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=cfg.dtype) + ) + self.b_in = nn.Parameter(torch.zeros(self.cfg.d_mlp, dtype=cfg.dtype)) - self.W_out = nn.Parameter(torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=cfg.dtype)) self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=cfg.dtype)) # hook on gate output but before act_fn @@ -1259,27 +1362,53 @@ def forward( self, x: Float[torch.Tensor, "batch pos d_model"] ) -> Float[torch.Tensor, "batch pos d_model"]: # Technically, all these einsums could be done with a single matmul, but this is more readable. 
- pre_act = self.hook_pre( - einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_gate) - ) # [batch, pos, d_mlp] - if self.cfg.act_fn is not None and not self.cfg.act_fn.endswith("_ln"): - pre_linear = self.hook_pre_linear( - einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in) + if self.cfg.load_in_4bit: + pre_act = self.hook_pre( + bnb.matmul_4bit(x, self.W_gate.t(), bias=None, quant_state=self.W_gate.quant_state) ) + else: + pre_act = self.hook_pre( + einsum( + "batch pos d_model, d_model d_mlp -> batch pos d_mlp", + x, + self.W_gate, + ) + ) # [batch, pos, d_mlp] + + if self.cfg.act_fn is not None and not self.cfg.act_fn.endswith("_ln"): + if self.cfg.load_in_4bit: + pre_linear = self.hook_pre_linear( + bnb.matmul_4bit(x, self.W_in.t(), bias=None, quant_state=self.W_in.quant_state) + ) + else: + pre_linear = self.hook_pre_linear( + einsum( + "batch pos d_model, d_model d_mlp -> batch pos d_mlp", + x, + self.W_in, + ) + ) + post_act = self.hook_post( (self.act_fn(pre_act) * pre_linear) + self.b_in ) # [batch, pos, d_mlp] else: mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] post_act = self.hook_post(self.ln(mid_act)) - return ( - einsum( - "batch pos d_mlp, d_mlp d_model -> batch pos d_model", - post_act, - self.W_out, + + if self.cfg.load_in_4bit: + return bnb.matmul_4bit( + post_act, self.W_out.t(), bias=None, quant_state=self.W_out.quant_state + ) + else: + return ( + einsum( + "batch pos d_mlp, d_mlp d_model -> batch pos d_model", + post_act, + self.W_out, + ) + + self.b_out ) - + self.b_out - ) class MoE(nn.Module): diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 08ff472fe..0212ae89c 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1191,6 +1191,7 @@ def convert_neel_model_config(official_model_name: str, **kwargs): def get_pretrained_model_config( model_name: str, + hf_cfg: 
Optional[dict] = None, checkpoint_index: Optional[int] = None, checkpoint_value: Optional[int] = None, fold_ln: bool = False, @@ -1210,6 +1211,8 @@ def get_pretrained_model_config( model_name: The name of the model. This can be either the official HuggingFace model name, or the name of a model trained by me (NeelNanda). + hf_cfg (dict, optional): Config of a loaded pretrained HF model, + converted to a dictionary. checkpoint_index (int, optional): If loading from a checkpoint, the index of the checkpoint to load. Defaults to None. checkpoint_value (int, optional): If loading from a checkpoint, the @@ -1301,6 +1304,8 @@ def get_pretrained_model_config( cfg_dict["device"] = device cfg_dict["n_devices"] = n_devices cfg_dict["default_prepend_bos"] = default_prepend_bos + if hf_cfg is not None: + cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False) cfg = HookedTransformerConfig.from_dict(cfg_dict) return cfg @@ -1764,6 +1769,14 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads) W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads) + + # in case of quantization, + # parameters should stay as bitsandbytes.nn.modules.Params4bit + if not cfg.load_in_4bit: + W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) + W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads) + W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads) + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V @@ -1785,7 +1798,10 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): ) W_O = llama.model.layers[l].self_attn.o_proj.weight - W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) + + if not cfg.load_in_4bit: + W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) + 
state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device) state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros( @@ -1794,13 +1810,20 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.ln2.w"] = llama.model.layers[l].post_attention_layernorm.weight - state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight.T - state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[l].mlp.gate_proj.weight.T + # in case of quantization, + # parameters should stay as bitsandbytes.nn.modules.Params4bit + if not cfg.load_in_4bit: + state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[l].mlp.gate_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[l].mlp.down_proj.weight.T + else: + state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight + state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[l].mlp.gate_proj.weight + state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[l].mlp.down_proj.weight + state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros( cfg.d_mlp, dtype=cfg.dtype, device=cfg.device ) - - state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[l].mlp.down_proj.weight.T state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros( cfg.d_model, dtype=cfg.dtype, device=cfg.device ) From 6cd64d5c8728ccaf7d10e836f882f66e368ae904 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Thu, 25 Apr 2024 22:34:58 +0200 Subject: [PATCH 64/73] removed deuplicate rearrange block (#555) * removed deuplicate rearrange block * removed duplicate variables * fixed param name --- transformer_lens/components.py | 2 -- transformer_lens/loading_from_pretrained.py | 7 ++----- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/transformer_lens/components.py b/transformer_lens/components.py index 66f6ff9d9..a5beeee7b 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -1041,8 
+1041,6 @@ def __init__( ) self.b_K = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)) self.b_V = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)) - self.b_K = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)) - self.b_V = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)) class GroupedQueryAttention(AbstractAttention): diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 0212ae89c..0a8d132cc 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -1766,16 +1766,13 @@ def convert_llama_weights(llama, cfg: HookedTransformerConfig): W_Q = llama.model.layers[l].self_attn.q_proj.weight W_K = llama.model.layers[l].self_attn.k_proj.weight W_V = llama.model.layers[l].self_attn.v_proj.weight - W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) - W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads) - W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads) # in case of quantization, # parameters should stay as bitsandbytes.nn.modules.Params4bit if not cfg.load_in_4bit: W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) - W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads) - W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads) + W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads) + W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads) state_dict[f"blocks.{l}.attn.W_Q"] = W_Q state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K From 1139cafe343dee7dc5bed9b9c127826e68c8263e Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Sat, 27 Apr 2024 01:31:59 +0200 Subject: [PATCH 65/73] Bert demo ci (#556) * revised demo testing to check all demos * separated demos * changed demo test order * rearranged test order * updated attribution patching to run differnt code in github * rearranged tests * updated 
header * updated grokking demo * updated bert for testing * updated bert demo * ran cells * removed github check * removed cells to skip * ignored output of loading cells * removed other tests --- demos/Attribution_Patching_Demo.ipynb | 2 +- demos/BERT.ipynb | 77 ++++++++++++++++++--------- demos/Grokking_Demo.ipynb | 44 ++++++++------- demos/Main_Demo.ipynb | 3 +- makefile | 3 +- 5 files changed, 83 insertions(+), 46 deletions(-) diff --git a/demos/Attribution_Patching_Demo.ipynb b/demos/Attribution_Patching_Demo.ipynb index cef67eb8b..8d8796629 100644 --- a/demos/Attribution_Patching_Demo.ipynb +++ b/demos/Attribution_Patching_Demo.ipynb @@ -1 +1 @@ -{"cells":[{"cell_type":"markdown","metadata":{},"source":["\n"," \"Open\n",""]},{"cell_type":"markdown","metadata":{},"source":[" # Attribution Patching Demo\n"," **Read [the accompanying blog post here](https://neelnanda.io/attribution-patching) for more context**\n"," This is an interim research report, giving a whirlwind tour of some unpublished work I did at Anthropic (credit to the then team - Chris Olah, Catherine Olsson, Nelson Elhage and Tristan Hume for help, support, and mentorship!)\n","\n"," The goal of this work is run activation patching at an industrial scale, by using gradient based attribution to approximate the technique - allow an arbitrary number of patches to be made on two forwards and a single backward pass\n","\n"," I have had less time than hoped to flesh out this investigation, but am writing up a rough investigation and comparison to standard activation patching on a few tasks to give a sense of the potential of this approach, and where it works vs falls down."]},{"cell_type":"markdown","metadata":{},"source":[" To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n","\n"," **Tips for reading this Colab:**\n"," * You can run all this code for yourself!\n"," * The graphs are interactive!\n"," * Use the table of contents pane in the sidebar to 
navigate\n"," * Collapse irrelevant sections with the dropdown arrows\n"," * Search the page using the search in the sidebar, not CTRL+F"]},{"cell_type":"markdown","metadata":{},"source":[" ## Setup (Ignore)"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Running as a Jupyter notebook - intended for development only!\n"]}],"source":["# Janky code to do different setup when run in a Colab notebook vs VSCode\n","DEBUG_MODE = False\n","try:\n"," import google.colab\n"," IN_COLAB = True\n"," print(\"Running as a Colab notebook\")\n"," %pip install transformer_lens==1.1.1\n"," %pip install torchtyping\n"," # Install my janky personal plotting utils\n"," %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n"," # Install another version of node that makes PySvelte work way faster\n"," !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n"," %pip install git+https://github.com/neelnanda-io/PySvelte.git\n"," # Needed for PySvelte to work, v3 came out and broke things...\n"," %pip install typeguard==2.13.3\n","except:\n"," IN_COLAB = False\n"," print(\"Running as a Jupyter notebook - intended for development only!\")\n"," from IPython import get_ipython\n","\n"," ipython = get_ipython()\n"," # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n"," ipython.magic(\"load_ext autoreload\")\n"," ipython.magic(\"autoreload 2\")"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":["# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n","import plotly.io as pio\n","\n","if IN_COLAB or not DEBUG_MODE:\n"," # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! 
Thus creating a debug mode.\n"," pio.renderers.default = \"colab\"\n","else:\n"," pio.renderers.default = \"notebook_connected\""]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":["# Import stuff\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import numpy as np\n","import einops\n","from fancy_einsum import einsum\n","import tqdm.notebook as tqdm\n","import random\n","from pathlib import Path\n","import plotly.express as px\n","from torch.utils.data import DataLoader\n","\n","from torchtyping import TensorType as TT\n","from typing import List, Union, Optional, Callable\n","from functools import partial\n","import copy\n","import itertools\n","import json\n","\n","from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n","import dataclasses\n","import datasets\n","from IPython.display import HTML, Markdown"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":["import pysvelte\n","\n","import transformer_lens\n","import transformer_lens.utils as utils\n","from transformer_lens.hook_points import (\n"," HookedRootModule,\n"," HookPoint,\n",") # Hooking utilities\n","from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache"]},{"cell_type":"markdown","metadata":{},"source":[" Plotting helper functions from a janky personal library of plotting utils. 
The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":["from neel_plotly import line, imshow, scatter"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["import transformer_lens.patching as patching"]},{"cell_type":"markdown","metadata":{},"source":[" ## IOI Patching Setup\n"," This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important."]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-small into HookedTransformer\n"]}],"source":["model = HookedTransformer.from_pretrained(\"gpt2-small\")\n","model.set_use_attn_result(True)"]},{"cell_type":"code","execution_count":9,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to\n","Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n","Answer token indices tensor([[ 5335, 1757],\n"," [ 1757, 5335],\n"," [ 4186, 3700],\n"," [ 3700, 4186],\n"," [ 6035, 15686],\n"," [15686, 6035],\n"," [ 5780, 14235],\n"," [14235, 5780]], device='cuda:0')\n"]}],"source":["prompts = ['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink 
to']\n","answers = [(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n","\n","clean_tokens = model.to_tokens(prompts)\n","# Swap each adjacent pair, with a hacky list comprehension\n","corrupted_tokens = clean_tokens[\n"," [(i+1 if i%2==0 else i-1) for i in range(len(clean_tokens)) ]\n"," ]\n","print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n","print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n","\n","answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=model.cfg.device)\n","print(\"Answer token indices\", answer_token_indices)"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 3.5519\n","Corrupted logit diff: -3.5519\n"]}],"source":["def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n"," if len(logits.shape)==3:\n"," # Get final logits only\n"," logits = logits[:, -1, :]\n"," correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n"," incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n"," return (correct_logits - incorrect_logits).mean()\n","\n","clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n","corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n","\n","clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n","print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n","\n","corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n","print(f\"Corrupted logit diff: {corrupted_logit_diff:.4f}\")"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Baseline is 1: 1.0000\n","Corrupted Baseline is 0: 0.0000\n"]}],"source":["CLEAN_BASELINE = 
clean_logit_diff\n","CORRUPTED_BASELINE = corrupted_logit_diff\n","def ioi_metric(logits, answer_token_indices=answer_token_indices):\n"," return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE - CORRUPTED_BASELINE)\n","\n","print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n","print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Patching\n"," In the following cells, we define attribution patching and use it in various ways on the model."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["Metric = Callable[[TT[\"batch_and_pos_dims\", \"d_model\"]], float]"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Value: 1.0\n","Clean Activations Cached: 220\n","Clean Gradients Cached: 220\n","Corrupted Value: 0.0\n","Corrupted Activations Cached: 220\n","Corrupted Gradients Cached: 220\n"]}],"source":["filter_not_qkv_input = lambda name: \"_input\" not in name\n","def get_cache_fwd_and_bwd(model, tokens, metric):\n"," model.reset_hooks()\n"," cache = {}\n"," def forward_cache_hook(act, hook):\n"," cache[hook.name] = act.detach()\n"," model.add_hook(filter_not_qkv_input, forward_cache_hook, \"fwd\")\n","\n"," grad_cache = {}\n"," def backward_cache_hook(act, hook):\n"," grad_cache[hook.name] = act.detach()\n"," model.add_hook(filter_not_qkv_input, backward_cache_hook, \"bwd\")\n","\n"," value = metric(model(tokens))\n"," value.backward()\n"," model.reset_hooks()\n"," return value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)\n","\n","clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(model, clean_tokens, ioi_metric)\n","print(\"Clean Value:\", clean_value)\n","print(\"Clean Activations Cached:\", len(clean_cache))\n","print(\"Clean Gradients Cached:\", 
len(clean_grad_cache))\n","corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(model, corrupted_tokens, ioi_metric)\n","print(\"Corrupted Value:\", corrupted_value)\n","print(\"Corrupted Activations Cached:\", len(corrupted_cache))\n","print(\"Corrupted Gradients Cached:\", len(corrupted_grad_cache))"]},{"cell_type":"markdown","metadata":{},"source":[" ### Attention Attribution\n"," The easiest thing to start with is to not even engage with the corrupted tokens/patching, but to look at the attribution of the attention patterns - that is, the linear approximation to what happens if you set each element of the attention pattern to zero. This, as it turns out, is a good proxy to what is going on with each head!\n"," Note that this is *not* the same as what we will later do with patching. In particular, this does not set up a careful counterfactual! It's a good tool for what's generally going on in this problem, but does not control for eg stuff that systematically boosts John > Mary in general, stuff that says \"I should activate the IOI circuit\", etc. Though using logit diff as our metric *does*\n"," Each element of the batch is independent and the metric is an average logit diff, so we can analyse each batch element independently here. 
We'll look at the first one, and then at the average across the whole batch (note - 4 prompts have indirect object before subject, 4 prompts have it the other way round, making the average pattern harder to interpret - I plot it over the first sequence of tokens as a mildly misleading reference).\n"," We can compare it to the interpretability in the wild diagram, and basically instantly recover most of the circuit!"]},{"cell_type":"code","execution_count":14,"metadata":{},"outputs":[],"source":["def create_attention_attr(clean_cache, clean_grad_cache) -> TT[\"batch\", \"layer\", \"head_index\", \"dest\", \"src\"]:\n"," attention_stack = torch.stack([clean_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," attention_grad_stack = torch.stack([clean_grad_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," attention_attr = attention_grad_stack * attention_stack\n"," attention_attr = einops.rearrange(attention_attr, \"layer batch head_index dest src -> batch layer head_index dest src\")\n"," return attention_attr\n","\n","attention_attr = create_attention_attr(clean_cache, clean_grad_cache)"]},{"cell_type":"code","execution_count":15,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4']\n","['L0H0+', 'L0H0-', 'L0H1+', 'L0H1-', 'L0H2+']\n","['L0H0Q', 'L0H0K', 'L0H0V', 'L0H1Q', 'L0H1K']\n"]}],"source":["HEAD_NAMES = [f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]\n","HEAD_NAMES_SIGNED = [f\"{name}{sign}\" for name in HEAD_NAMES for sign in [\"+\", \"-\"]]\n","HEAD_NAMES_QKV = [f\"{name}{act_name}\" for name in HEAD_NAMES for act_name in [\"Q\", \"K\", \"V\"]]\n","print(HEAD_NAMES[:5])\n","print(HEAD_NAMES_SIGNED[:5])\n","print(HEAD_NAMES_QKV[:5])"]},{"cell_type":"markdown","metadata":{},"source":[" An extremely janky way to plot the attention attribution patterns. 
We scale them to be in [-1, 1], split each head into a positive and negative part (so all of it is in [0, 1]), and then plot the top 20 head-halves (a head can appear twice!) by the max value of the attribution pattern."]},{"cell_type":"code","execution_count":16,"metadata":{},"outputs":[{"data":{"text/markdown":["### Attention Attribution for first sequence"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["### Summed Attention Attribution for all sequences"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\n"]}],"source":["def plot_attention_attr(attention_attr, tokens, top_k=20, index=0, title=\"\"):\n"," if len(tokens.shape)==2:\n"," tokens = tokens[index]\n"," if len(attention_attr.shape)==5:\n"," attention_attr = attention_attr[index]\n"," attention_attr_pos = attention_attr.clamp(min=-1e-5)\n"," attention_attr_neg = - attention_attr.clamp(max=1e-5)\n"," attention_attr_signed = torch.stack([attention_attr_pos, attention_attr_neg], dim=0)\n"," attention_attr_signed = einops.rearrange(attention_attr_signed, \"sign layer head_index dest src -> (layer head_index sign) dest src\")\n"," attention_attr_signed = attention_attr_signed / attention_attr_signed.max()\n"," attention_attr_indices = attention_attr_signed.max(-1).values.max(-1).values.argsort(descending=True)\n"," # print(attention_attr_indices.shape)\n"," # print(attention_attr_indices)\n"," attention_attr_signed = attention_attr_signed[attention_attr_indices, :, :]\n"," head_labels = [HEAD_NAMES_SIGNED[i.item()] for i in attention_attr_indices]\n","\n"," if title: display(Markdown(\"### \"+title))\n"," display(pysvelte.AttentionMulti(tokens=model.to_str_tokens(tokens), attention=attention_attr_signed.permute(1, 2, 0)[:, :, :top_k], head_labels=head_labels[:top_k]))\n","\n","plot_attention_attr(attention_attr, clean_tokens, index=0, title=\"Attention Attribution for first sequence\")\n","\n","plot_attention_attr(attention_attr.sum(0), clean_tokens[0], title=\"Summed Attention Attribution for all sequences\")\n","print(\"Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Attribution Patching\n"," In the following sections, I will implement various kinds of attribution patching, and 
then compare them to the activation patching patterns (activation patching code copied from [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo))\n"," ### Residual Stream Patching\n","
Note: We add up across both d_model and batch (Explanation).\n"," We add up along d_model because we're taking the dot product - the derivative *is* the linear map that locally linearly approximates the metric, and so we take the dot product of our change vector with the derivative vector. Equivalent, we look at the effect of changing each coordinate independently, and then combine them by adding it up - it's linear, so this totally works.\n"," We add up across batch because we're taking the average of the metric, so each individual batch element provides `1/batch_size` of the overall effect. Because each batch element is independent of the others and no information moves between activations for different inputs, the batched version is equivalent to doing attribution patching separately for each input, and then averaging - in this second version the metric per input is *not* divided by batch_size because we don't average.
"]},{"cell_type":"code","execution_count":17,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_residual(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," clean_residual, residual_labels = clean_cache.accumulated_resid(-1, incl_mid=True, return_labels=True)\n"," corrupted_residual = corrupted_cache.accumulated_resid(-1, incl_mid=True, return_labels=False)\n"," corrupted_grad_residual = corrupted_grad_cache.accumulated_resid(-1, incl_mid=True, return_labels=False)\n"," residual_attr = einops.reduce(\n"," corrupted_grad_residual * (clean_residual - corrupted_residual),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return residual_attr, residual_labels\n","\n","residual_attr, residual_labels = attr_patch_residual(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(residual_attr, y=residual_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Residual Attribution Patching\")\n","\n","# ### Layer Output Patching"]},{"cell_type":"code","execution_count":18,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_layer_out(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," clean_layer_out, labels = clean_cache.decompose_resid(-1, return_labels=True)\n"," corrupted_layer_out = corrupted_cache.decompose_resid(-1, return_labels=False)\n"," corrupted_grad_layer_out = corrupted_grad_cache.decompose_resid(-1, return_labels=False)\n"," layer_out_attr = einops.reduce(\n"," corrupted_grad_layer_out * (clean_layer_out - corrupted_layer_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return layer_out_attr, labels\n","\n","layer_out_attr, layer_out_labels = attr_patch_layer_out(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(layer_out_attr, y=layer_out_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Layer Output Attribution Patching\")"]},{"cell_type":"code","execution_count":19,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_head_out(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_out = clean_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_head_out = corrupted_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_grad_head_out = corrupted_grad_cache.stack_head_results(-1, return_labels=False)\n"," head_out_attr = einops.reduce(\n"," corrupted_grad_head_out * (clean_head_out - corrupted_head_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return head_out_attr, labels\n","\n","head_out_attr, head_out_labels = attr_patch_head_out(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(head_out_attr, y=head_out_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Head Output Attribution Patching\")\n","sum_head_out_attr = einops.reduce(head_out_attr, \"(layer head) pos -> layer head\", \"sum\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n","imshow(sum_head_out_attr, yaxis=\"Layer\", xaxis=\"Head Index\", title=\"Head Output Attribution Patching Sum Over Pos\")"]},{"cell_type":"markdown","metadata":{},"source":[" ### Head Activation Patching\n"," Intuitively, a head has three inputs, keys, queries and values. We can patch each of these individually to get a sense for where the important part of each head's input comes from!\n"," As a sanity check, we also do this for the mixed value. 
The result is a linear map of this (`z @ W_O == result`), so this is the same as patching the output of the head.\n"," We plot both the patch for each head over each position, and summed over position (it tends to be pretty sparse, so the latter is the same)"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[{"data":{"text/markdown":["#### Key Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Query Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Mixed Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","def stack_head_vector_from_cache(\n"," cache, \n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"]\n"," ) -> TT[\"layer_and_head_index\", \"batch\", \"pos\", \"d_head\"]:\n"," \"\"\"Stacks the head vectors from the cache from a specific activation (key, query, value or mixed_value (z)) into a single tensor.\"\"\"\n"," stacked_head_vectors = torch.stack([cache[activation_name, l] for l in range(model.cfg.n_layers)], dim=0)\n"," stacked_head_vectors = einops.rearrange(\n"," stacked_head_vectors,\n"," \"layer batch pos head_index d_head -> (layer head_index) batch pos d_head\"\n"," )\n"," return stacked_head_vectors\n","\n","def attr_patch_head_vector(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"],\n"," ) -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_vector = stack_head_vector_from_cache(clean_cache, activation_name)\n"," corrupted_head_vector = stack_head_vector_from_cache(corrupted_cache, activation_name)\n"," corrupted_grad_head_vector = stack_head_vector_from_cache(corrupted_grad_cache, activation_name)\n"," head_vector_attr = einops.reduce(\n"," corrupted_grad_head_vector * (clean_head_vector - corrupted_head_vector),\n"," \"component batch pos d_head -> component pos\",\n"," \"sum\"\n"," )\n"," return head_vector_attr, labels\n","\n","head_vector_attr_dict = {}\n","for activation_name, activation_name_full in [(\"k\", \"Key\"), (\"q\", \"Query\"), (\"v\", \"Value\"), (\"z\", \"Mixed Value\")]:\n"," display(Markdown(f\"#### {activation_name_full} Head Vector Attribution Patching\"))\n"," head_vector_attr_dict[activation_name], head_vector_labels = attr_patch_head_vector(clean_cache, corrupted_cache, corrupted_grad_cache, activation_name)\n"," 
imshow(head_vector_attr_dict[activation_name], y=head_vector_labels, yaxis=\"Component\", xaxis=\"Position\", title=f\"{activation_name_full} Attribution Patching\")\n"," sum_head_vector_attr = einops.reduce(head_vector_attr_dict[activation_name], \"(layer head) pos -> layer head\", \"sum\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n"," imshow(sum_head_vector_attr, yaxis=\"Layer\", xaxis=\"Head Index\", title=f\"{activation_name_full} Attribution Patching Sum Over Pos\")"]},{"cell_type":"code","execution_count":21,"metadata":{},"outputs":[{"data":{"text/markdown":["### Head Pattern Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","def stack_head_pattern_from_cache(\n"," cache, \n"," ) -> TT[\"layer_and_head_index\", \"batch\", \"dest_pos\", \"src_pos\"]:\n"," \"\"\"Stacks the head patterns from the cache into a single tensor.\"\"\"\n"," stacked_head_pattern = torch.stack([cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," stacked_head_pattern = einops.rearrange(\n"," stacked_head_pattern,\n"," \"layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos\"\n"," )\n"," return stacked_head_pattern\n","\n","def attr_patch_head_pattern(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"dest_pos\", \"src_pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_pattern = stack_head_pattern_from_cache(clean_cache)\n"," corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)\n"," corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)\n"," head_pattern_attr = einops.reduce(\n"," corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern),\n"," \"component batch dest_pos src_pos -> component dest_pos src_pos\",\n"," \"sum\"\n"," )\n"," return head_pattern_attr, labels\n","\n","head_pattern_attr, labels = attr_patch_head_pattern(clean_cache, corrupted_cache, corrupted_grad_cache)\n","\n","plot_attention_attr(einops.rearrange(head_pattern_attr, \"(layer head) dest src -> layer head dest src\", layer=model.cfg.n_layers, head=model.cfg.n_heads), clean_tokens, index=0, title=\"Head Pattern Attribution Patching\")"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_head_vector_grad_input_from_grad_cache(\n"," grad_cache: ActivationCache, \n"," activation_name: Literal[\"q\", \"k\", \"v\"],\n"," layer: int\n"," ) -> TT[\"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," vector_grad = grad_cache[activation_name, layer]\n"," ln_scales = grad_cache[\"scale\", layer, \"ln1\"]\n"," attn_layer_object = model.blocks[layer].attn\n"," if activation_name == \"q\":\n"," W = attn_layer_object.W_Q\n"," elif activation_name == \"k\":\n"," W = attn_layer_object.W_K\n"," elif activation_name == \"v\":\n"," W = attn_layer_object.W_V\n"," else:\n"," raise ValueError(\"Invalid activation name\")\n","\n"," return einsum(\"batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model\", vector_grad, ln_scales.squeeze(-1), W)\n","\n","def get_stacked_head_vector_grad_input(grad_cache, activation_name: Literal[\"q\", \"k\", \"v\"]) -> TT[\"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack([get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, l) for l in range(model.cfg.n_layers)], dim=0)\n","\n","def get_full_vector_grad_input(grad_cache) -> TT[\"qkv\", \"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack([get_stacked_head_vector_grad_input(grad_cache, activation_name) for activation_name in ['q', 'k', 'v']], dim=0)\n","\n","def attr_patch_head_path(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache\n"," ) -> TT[\"qkv\", \"dest_component\", \"src_component\", \"pos\"]:\n"," \"\"\"\n"," Computes the attribution patch along the path between each pair of heads.\n","\n"," Sets this to zero for the path from any late head to any early head\n","\n"," \"\"\"\n"," start_labels = HEAD_NAMES\n"," end_labels = HEAD_NAMES_QKV\n"," full_vector_grad_input = 
get_full_vector_grad_input(corrupted_grad_cache)\n"," clean_head_result_stack = clean_cache.stack_head_results(-1)\n"," corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)\n"," diff_head_result = einops.rearrange(\n"," clean_head_result_stack - corrupted_head_result_stack,\n"," \"(layer head_index) batch pos d_model -> layer batch pos head_index d_model\",\n"," layer = model.cfg.n_layers,\n"," head_index = model.cfg.n_heads,\n"," )\n"," path_attr = einsum(\n"," \"qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos\", \n"," full_vector_grad_input, \n"," diff_head_result)\n"," correct_layer_order_mask = (\n"," torch.arange(model.cfg.n_layers)[None, :, None, None, None, None] > \n"," torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]).to(path_attr.device)\n"," zero = torch.zeros(1, device=path_attr.device)\n"," path_attr = torch.where(correct_layer_order_mask, path_attr, zero)\n","\n"," path_attr = einops.rearrange(\n"," path_attr,\n"," \"qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos\",\n"," )\n"," return path_attr, end_labels, start_labels\n","\n","head_path_attr, end_labels, start_labels = attr_patch_head_path(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(head_path_attr.sum(-1), y=end_labels, yaxis=\"Path End (Head Input)\", x=start_labels, xaxis=\"Path Start (Head Output)\", title=\"Head Path Attribution Patching\")"]},{"cell_type":"markdown","metadata":{},"source":[" This is hard to parse. Here's an experiment with filtering for the most important heads and showing their paths."]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)\n","line(head_out_values)\n","top_head_indices = head_out_indices[:22].sort().values\n","top_end_indices = []\n","top_end_labels = []\n","top_start_indices = []\n","top_start_labels = []\n","for i in top_head_indices:\n"," i = i.item()\n"," top_start_indices.append(i)\n"," top_start_labels.append(start_labels[i])\n"," for j in range(3):\n"," top_end_indices.append(3*i+j)\n"," top_end_labels.append(end_labels[3*i+j])\n","\n","imshow(head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1), y=top_end_labels, yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=\"Head Path Attribution Patching (Filtered for Top Heads)\")"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["for j, composition_type in enumerate([\"Query\", \"Key\", \"Value\"]):\n"," imshow(head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1), y=top_end_labels[j::3], yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=f\"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)\")"]},{"cell_type":"code","execution_count":25,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["top_head_path_attr = einops.rearrange(head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1), \"(head_end qkv) head_start -> qkv head_end head_start\", qkv=3)\n","imshow(top_head_path_attr, y=[i[:-1] for i in top_end_labels[::3]], yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=f\"Head Path Attribution Patching (Filtered for Top Heads)\", facet_col=0, facet_labels=[\"Query\", \"Key\", \"Value\"])"]},{"cell_type":"markdown","metadata":{},"source":[" Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - Q input means )"]},{"cell_type":"code","execution_count":26,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["interesting_heads = [5 * model.cfg.n_heads + 5, 8 * model.cfg.n_heads + 6, 9 * model.cfg.n_heads + 9]\n","interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]\n","for head_index, label in zip(interesting_heads, interesting_head_labels):\n"," in_paths = head_path_attr[3*head_index:3*head_index+3].sum(-1)\n"," out_paths = head_path_attr[:, head_index].sum(-1)\n"," out_paths = einops.rearrange(out_paths, \"(layer_head qkv) -> qkv layer_head\", qkv=3)\n"," all_paths = torch.cat([in_paths, out_paths], dim=0)\n"," all_paths = einops.rearrange(all_paths, \"path_type (layer head) -> path_type layer head\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n"," imshow(all_paths, facet_col=0, facet_labels=[\"Query (In)\", \"Key (In)\", \"Value (In)\", \"Query (Out)\", \"Key (Out)\", \"Value (Out)\"], title=f\"Input and Output Paths for head {label}\", yaxis=\"Layer\", xaxis=\"Head\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Validating Attribution vs Activation Patching\n"," Let's now compare attribution and activation patching. Generally it's a decent approximation! The main place it fails is MLP0 and the residual stream\n"," My fuzzy intuition is that attribution patching works badly for \"big\" things which are poorly modelled as linear approximations, and works well for \"small\" things which are more like incremental changes. Anything involving replacing the embedding is a \"big\" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an \"extended embedding\" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.\n"," See more discussion in the accompanying blog post!\n"]},{"cell_type":"markdown","metadata":{},"source":[" First do some refactoring to make attribution patching more generic. 
We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!"]},{"cell_type":"code","execution_count":27,"metadata":{},"outputs":[],"source":["attribution_cache_dict = {}\n","for key in corrupted_grad_cache.cache_dict.keys():\n"," attribution_cache_dict[key] = corrupted_grad_cache.cache_dict[key] * (clean_cache.cache_dict[key] - corrupted_cache.cache_dict[key])\n","attr_cache = ActivationCache(attribution_cache_dict, model)"]},{"cell_type":"markdown","metadata":{},"source":[" By block: For each head we patch the starting residual stream, attention output + MLP output"]},{"cell_type":"code","execution_count":28,"metadata":{},"outputs":[],"source":["str_tokens = model.to_str_tokens(clean_tokens[0])\n","context_length = len(str_tokens)"]},{"cell_type":"code","execution_count":29,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"95a5290e11b64b6a95ef5dd37d027c7a","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_block_act_patch_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_block_act_patch_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Activation Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"]},{"cell_type":"code","execution_count":30,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_block_every(attr_cache):\n"," resid_pre_attr = einops.reduce(\n"," attr_cache.stack_activation(\"resid_pre\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," attn_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"attn_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," mlp_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"mlp_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n","\n"," every_block_attr_patch_result = torch.stack([resid_pre_attr, attn_out_attr, mlp_out_attr], dim=0)\n"," return every_block_attr_patch_result\n","every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)\n","imshow(every_block_attr_patch_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Attribution Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"]},{"cell_type":"code","execution_count":31,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_block_attr_patch_result.reshape(3, -1), x=every_block_act_patch_result.reshape(3, -1), facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Attribution vs Activation Patching Per Block\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", hover=[f\"Layer {l}, Position {p}, |{str_tokens[p]}|\" for l in range(model.cfg.n_layers) for p in range(context_length)], color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer pos)\", pos=context_length), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow."]},{"cell_type":"code","execution_count":32,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"18b2e6b0985b40cd8c0cd1a16ba62975","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)"]},{"cell_type":"code","execution_count":33,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_all_pos_every(attr_cache):\n"," head_out_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_q_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_k_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_v_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack([head_out_all_pos_attr, head_q_all_pos_attr, head_k_all_pos_attr, head_v_all_pos_attr, head_pattern_all_pos_attr])\n"," \n","every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(attr_cache)\n","imshow(every_head_all_pos_attr_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)"]},{"cell_type":"code","execution_count":34,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_head_all_pos_attr_patch_result.reshape(5, -1), x=every_head_all_pos_act_patch_result.reshape(5, -1), facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution vs Activation Patching Per Head (All Pos)\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", include_diag=True, hover=head_out_labels, color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer head)\", head=model.cfg.n_heads), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" We see pretty good results in general, but significant errors for heads L5H5 on query and moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. But each of these is fine for pattern and output. My guess is that the problem is that these have pretty saturated attention on a single token, and the linear approximation is thus not great on the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!\n"," Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has indirect object and subject in the opposite order (IO first). But otherwise everything lines up between the prompts"]},{"cell_type":"code","execution_count":35,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["graph_tok_labels = [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]\n","imshow(clean_cache[\"pattern\", 5][:, 5], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L5H5\", facet_name=\"Prompt\")\n","imshow(clean_cache[\"pattern\", 10][:, 7], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L10H7\", facet_name=\"Prompt\")\n","imshow(clean_cache[\"pattern\", 11][:, 10], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L11H10\", facet_name=\"Prompt\")\n","\n","\n","# [markdown]"]},{"cell_type":"code","execution_count":36,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"06f39489001845849fbc7446a07066f4","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","every_head_by_pos_act_patch_result = einops.rearrange(every_head_by_pos_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n","imshow(every_head_by_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=head_out_labels)"]},{"cell_type":"code","execution_count":37,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_by_pos_every(attr_cache):\n"," head_out_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_q_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_k_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_v_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer dest_pos head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack([head_out_by_pos_attr, head_q_by_pos_attr, head_k_by_pos_attr, head_v_by_pos_attr, head_pattern_by_pos_attr])\n","every_head_by_pos_attr_patch_result = get_attr_patch_attn_head_by_pos_every(attr_cache)\n","every_head_by_pos_attr_patch_result = einops.rearrange(every_head_by_pos_attr_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n","imshow(every_head_by_pos_attr_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=head_out_labels)"]},{"cell_type":"code","execution_count":38,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_head_by_pos_attr_patch_result.reshape(5, -1), x=every_head_by_pos_act_patch_result.reshape(5, -1), facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution vs Activation Patching Per Head (by Pos)\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", include_diag=True, hover=[f\"{label} {tok}\" for label in head_out_labels for tok in graph_tok_labels], color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer head pos)\", head=model.cfg.n_heads, pos = 15), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Factual Knowledge Patching Example\n"," Incomplete, but maybe of interest!\n"," Note that I have better results with the corrupted prompt as having random words rather than Colosseum."]},{"cell_type":"code","execution_count":39,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-xl into HookedTransformer\n","Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Paris']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.73 Prob: 95.80% Token: | Paris|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.73\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m95.80\u001b[0m\u001b[1m% Token: | Paris|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n","Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n","Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n","Top 3th token. Logit: 14.58 Prob: 0.21% Token: | É|\n","Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n","Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n","Top 6th token. Logit: 13.77 Prob: 0.09% Token: | Le|\n","Top 7th token. Logit: 13.66 Prob: 0.08% Token: | Ang|\n","Top 8th token. Logit: 13.43 Prob: 0.06% Token: | V|\n","Top 9th token. Logit: 13.42 Prob: 0.06% Token: | Stras|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Paris', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Paris'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Tokenized prompt: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Rome']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.02 Prob: 83.70% Token: | Rome|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.02\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m83.70\u001b[0m\u001b[1m% Token: | Rome|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.02 Prob: 83.70% Token: | Rome|\n","Top 1th token. Logit: 17.03 Prob: 4.23% Token: | Naples|\n","Top 2th token. Logit: 16.85 Prob: 3.51% Token: | Pompe|\n","Top 3th token. Logit: 16.14 Prob: 1.73% Token: | Ver|\n","Top 4th token. Logit: 15.87 Prob: 1.32% Token: | Florence|\n","Top 5th token. Logit: 14.77 Prob: 0.44% Token: | Roma|\n","Top 6th token. Logit: 14.68 Prob: 0.40% Token: | Milan|\n","Top 7th token. Logit: 14.66 Prob: 0.39% Token: | ancient|\n","Top 8th token. Logit: 14.37 Prob: 0.29% Token: | Pal|\n","Top 9th token. Logit: 14.30 Prob: 0.27% Token: | Constantinople|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Rome', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Rome'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"}],"source":["gpt2_xl = HookedTransformer.from_pretrained(\"gpt2-xl\")\n","clean_prompt = \"The Eiffel Tower is located in the city of\"\n","clean_answer = \" Paris\"\n","# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n","corrupted_prompt = \"The Colosseum is located in the city of\"\n","corrupted_answer = \" Rome\"\n","utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)\n","utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)"]},{"cell_type":"code","execution_count":40,"metadata":{},"outputs":[],"source":["clean_answer_index = gpt2_xl.to_single_token(clean_answer)\n","corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)\n","def factual_logit_diff(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return logits[0, -1, clean_answer_index] - logits[0, -1, corrupted_answer_index]"]},{"cell_type":"code","execution_count":41,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 10.634519577026367\n","Corrupted logit diff: -8.988396644592285\n","Clean Metric: tensor(1., device='cuda:0', grad_fn=)\n","Corrupted Metric: tensor(0., device='cuda:0', grad_fn=)\n"]}],"source":["clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)\n","CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()\n","corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)\n","CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()\n","\n","def factual_metric(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return (factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL) / (CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean logit diff:\", CLEAN_LOGIT_DIFF_FACTUAL)\n","print(\"Corrupted 
logit diff:\", CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean Metric:\", factual_metric(clean_logits))\n","print(\"Corrupted Metric:\", factual_metric(corrupted_logits))"]},{"cell_type":"code","execution_count":42,"metadata":{},"outputs":[],"source":["# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)"]},{"cell_type":"code","execution_count":43,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Corrupted: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n"]}],"source":["clean_tokens = gpt2_xl.to_tokens(clean_prompt)\n","clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)\n","corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)\n","corrupted_str_tokens = gpt2_xl.to_str_tokens(corrupted_prompt)\n","print(\"Clean:\", clean_str_tokens)\n","print(\"Corrupted:\", corrupted_str_tokens)"]},{"cell_type":"code","execution_count":44,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"b767eef7a3cd49b9b3cb6e5301463f08","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/48 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def act_patch_residual(clean_cache, corrupted_tokens, model: HookedTransformer, metric):\n"," if len(corrupted_tokens.shape)==2:\n"," corrupted_tokens = corrupted_tokens[0]\n"," residual_patches = torch.zeros((model.cfg.n_layers, len(corrupted_tokens)), device=model.cfg.device)\n"," def residual_hook(resid_pre, hook, layer, pos):\n"," resid_pre[:, pos, :] = clean_cache[\"resid_pre\", layer][:, pos, :]\n"," return resid_pre\n"," for layer in tqdm.tqdm(range(model.cfg.n_layers)):\n"," for pos in range(len(corrupted_tokens)):\n"," patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[(f\"blocks.{layer}.hook_resid_pre\", partial(residual_hook, layer=layer, pos=pos))])\n"," residual_patches[layer, pos] = metric(patched_logits).item()\n"," return residual_patches\n","\n","residual_act_patch = act_patch_residual(clean_cache, corrupted_tokens, gpt2_xl, factual_metric)\n","\n","imshow(residual_act_patch, title=\"Factual Recall Patching (Residual)\", xaxis=\"Position\", yaxis=\"Layer\", x=clean_str_tokens)"]}],"metadata":{"kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.13"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"}}},"nbformat":4,"nbformat_minor":2} +{"cells":[{"cell_type":"markdown","metadata":{},"source":["\n"," \"Open\n",""]},{"cell_type":"markdown","metadata":{},"source":[" # Attribution Patching Demo\n"," **Read [the accompanying blog post here](https://neelnanda.io/attribution-patching) for more context**\n"," This is an interim research report, giving a whirlwind tour of some unpublished work I did at Anthropic (credit to the then team - Chris Olah, Catherine Olsson, Nelson Elhage and 
Tristan Hume for help, support, and mentorship!)\n","\n"," The goal of this work is run activation patching at an industrial scale, by using gradient based attribution to approximate the technique - allow an arbitrary number of patches to be made on two forwards and a single backward pass\n","\n"," I have had less time than hoped to flesh out this investigation, but am writing up a rough investigation and comparison to standard activation patching on a few tasks to give a sense of the potential of this approach, and where it works vs falls down."]},{"cell_type":"markdown","metadata":{},"source":[" To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n","\n"," **Tips for reading this Colab:**\n"," * You can run all this code for yourself!\n"," * The graphs are interactive!\n"," * Use the table of contents pane in the sidebar to navigate\n"," * Collapse irrelevant sections with the dropdown arrows\n"," * Search the page using the search in the sidebar, not CTRL+F"]},{"cell_type":"markdown","metadata":{},"source":[" ## Setup (Ignore)"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Running as a Jupyter notebook - intended for development only!\n"]},{"name":"stderr","output_type":"stream","text":["/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_25358/2480103146.py:24: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n"," ipython.magic(\"load_ext autoreload\")\n","/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_25358/2480103146.py:25: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n"," ipython.magic(\"autoreload 2\")\n"]}],"source":["# Janky code to do different setup when run in a Colab notebook vs VSCode\n","import os\n","\n","DEBUG_MODE = False\n","IN_GITHUB = 
os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n","try:\n"," import google.colab\n","\n"," IN_COLAB = True\n"," print(\"Running as a Colab notebook\")\n","except:\n"," IN_COLAB = False\n"," print(\"Running as a Jupyter notebook - intended for development only!\")\n"," from IPython import get_ipython\n","\n"," ipython = get_ipython()\n"," # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n"," ipython.magic(\"load_ext autoreload\")\n"," ipython.magic(\"autoreload 2\")\n","\n","if IN_COLAB or IN_GITHUB:\n"," %pip install transformer_lens\n"," %pip install torchtyping\n"," # Install my janky personal plotting utils\n"," %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n"," # Install another version of node that makes PySvelte work way faster\n"," %pip install circuitsvis\n"," # Needed for PySvelte to work, v3 came out and broke things...\n"," %pip install typeguard==2.13.3"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n","import plotly.io as pio\n","\n","if IN_COLAB or not DEBUG_MODE:\n"," # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! 
Thus creating a debug mode.\n"," pio.renderers.default = \"colab\"\n","else:\n"," pio.renderers.default = \"notebook_connected\""]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[{"ename":"ModuleNotFoundError","evalue":"No module named 'torchtyping'","output_type":"error","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)","Cell \u001b[0;32mIn[3], line 15\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mplotly\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mexpress\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpx\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DataLoader\n\u001b[0;32m---> 15\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorchtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m TensorType \u001b[38;5;28;01mas\u001b[39;00m TT\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m List, Union, Optional, Callable\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfunctools\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m partial\n","\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torchtyping'"]}],"source":["# Import stuff\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import numpy as np\n","import einops\n","from fancy_einsum import einsum\n","import tqdm.notebook as tqdm\n","import random\n","from pathlib import Path\n","import plotly.express as px\n","from torch.utils.data import 
DataLoader\n","\n","from torchtyping import TensorType as TT\n","from typing import List, Union, Optional, Callable\n","from functools import partial\n","import copy\n","import itertools\n","import json\n","\n","from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n","import dataclasses\n","import datasets\n","from IPython.display import HTML, Markdown"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":["import transformer_lens\n","import transformer_lens.utils as utils\n","from transformer_lens.hook_points import (\n"," HookedRootModule,\n"," HookPoint,\n",") # Hooking utilities\n","from transformer_lens import (\n"," HookedTransformer,\n"," HookedTransformerConfig,\n"," FactoredMatrix,\n"," ActivationCache,\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":["from neel_plotly import line, imshow, scatter"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["import transformer_lens.patching as patching"]},{"cell_type":"markdown","metadata":{},"source":[" ## IOI Patching Setup\n"," This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important."]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-small into HookedTransformer\n"]}],"source":["model = HookedTransformer.from_pretrained(\"gpt2-small\")\n","model.set_use_attn_result(True)"]},{"cell_type":"code","execution_count":9,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean string 0 
<|endoftext|>When John and Mary went to the shops, John gave the bag to\n","Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n","Answer token indices tensor([[ 5335, 1757],\n"," [ 1757, 5335],\n"," [ 4186, 3700],\n"," [ 3700, 4186],\n"," [ 6035, 15686],\n"," [15686, 6035],\n"," [ 5780, 14235],\n"," [14235, 5780]], device='cuda:0')\n"]}],"source":["prompts = [\n"," \"When John and Mary went to the shops, John gave the bag to\",\n"," \"When John and Mary went to the shops, Mary gave the bag to\",\n"," \"When Tom and James went to the park, James gave the ball to\",\n"," \"When Tom and James went to the park, Tom gave the ball to\",\n"," \"When Dan and Sid went to the shops, Sid gave an apple to\",\n"," \"When Dan and Sid went to the shops, Dan gave an apple to\",\n"," \"After Martin and Amy went to the park, Amy gave a drink to\",\n"," \"After Martin and Amy went to the park, Martin gave a drink to\",\n","]\n","answers = [\n"," (\" Mary\", \" John\"),\n"," (\" John\", \" Mary\"),\n"," (\" Tom\", \" James\"),\n"," (\" James\", \" Tom\"),\n"," (\" Dan\", \" Sid\"),\n"," (\" Sid\", \" Dan\"),\n"," (\" Martin\", \" Amy\"),\n"," (\" Amy\", \" Martin\"),\n","]\n","\n","clean_tokens = model.to_tokens(prompts)\n","# Swap each adjacent pair, with a hacky list comprehension\n","corrupted_tokens = clean_tokens[\n"," [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]\n","]\n","print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n","print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n","\n","answer_token_indices = torch.tensor(\n"," [\n"," [model.to_single_token(answers[i][j]) for j in range(2)]\n"," for i in range(len(answers))\n"," ],\n"," device=model.cfg.device,\n",")\n","print(\"Answer token indices\", answer_token_indices)"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 3.5519\n","Corrupted logit diff: 
-3.5519\n"]}],"source":["def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n"," if len(logits.shape) == 3:\n"," # Get final logits only\n"," logits = logits[:, -1, :]\n"," correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n"," incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n"," return (correct_logits - incorrect_logits).mean()\n","\n","\n","clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n","corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n","\n","clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n","print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n","\n","corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n","print(f\"Corrupted logit diff: {corrupted_logit_diff:.4f}\")"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Baseline is 1: 1.0000\n","Corrupted Baseline is 0: 0.0000\n"]}],"source":["CLEAN_BASELINE = clean_logit_diff\n","CORRUPTED_BASELINE = corrupted_logit_diff\n","\n","\n","def ioi_metric(logits, answer_token_indices=answer_token_indices):\n"," return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (\n"," CLEAN_BASELINE - CORRUPTED_BASELINE\n"," )\n","\n","\n","print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n","print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Patching\n"," In the following cells, we define attribution patching and use it in various ways on the model."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["Metric = Callable[[TT[\"batch_and_pos_dims\", \"d_model\"]], float]"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Value: 1.0\n","Clean Activations 
Cached: 220\n","Clean Gradients Cached: 220\n","Corrupted Value: 0.0\n","Corrupted Activations Cached: 220\n","Corrupted Gradients Cached: 220\n"]}],"source":["filter_not_qkv_input = lambda name: \"_input\" not in name\n","\n","\n","def get_cache_fwd_and_bwd(model, tokens, metric):\n"," model.reset_hooks()\n"," cache = {}\n","\n"," def forward_cache_hook(act, hook):\n"," cache[hook.name] = act.detach()\n","\n"," model.add_hook(filter_not_qkv_input, forward_cache_hook, \"fwd\")\n","\n"," grad_cache = {}\n","\n"," def backward_cache_hook(act, hook):\n"," grad_cache[hook.name] = act.detach()\n","\n"," model.add_hook(filter_not_qkv_input, backward_cache_hook, \"bwd\")\n","\n"," value = metric(model(tokens))\n"," value.backward()\n"," model.reset_hooks()\n"," return (\n"," value.item(),\n"," ActivationCache(cache, model),\n"," ActivationCache(grad_cache, model),\n"," )\n","\n","\n","clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(\n"," model, clean_tokens, ioi_metric\n",")\n","print(\"Clean Value:\", clean_value)\n","print(\"Clean Activations Cached:\", len(clean_cache))\n","print(\"Clean Gradients Cached:\", len(clean_grad_cache))\n","corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(\n"," model, corrupted_tokens, ioi_metric\n",")\n","print(\"Corrupted Value:\", corrupted_value)\n","print(\"Corrupted Activations Cached:\", len(corrupted_cache))\n","print(\"Corrupted Gradients Cached:\", len(corrupted_grad_cache))"]},{"cell_type":"markdown","metadata":{},"source":[" ### Attention Attribution\n"," The easiest thing to start with is to not even engage with the corrupted tokens/patching, but to look at the attribution of the attention patterns - that is, the linear approximation to what happens if you set each element of the attention pattern to zero. This, as it turns out, is a good proxy to what is going on with each head!\n"," Note that this is *not* the same as what we will later do with patching. 
In particular, this does not set up a careful counterfactual! It's a good tool for what's generally going on in this problem, but does not control for eg stuff that systematically boosts John > Mary in general, stuff that says \"I should activate the IOI circuit\", etc. Though using logit diff as our metric *does*\n"," Each element of the batch is independent and the metric is an average logit diff, so we can analyse each batch element independently here. We'll look at the first one, and then at the average across the whole batch (note - 4 prompts have indirect object before subject, 4 prompts have it the other way round, making the average pattern harder to interpret - I plot it over the first sequence of tokens as a mildly misleading reference).\n"," We can compare it to the interpretability in the wild diagram, and basically instantly recover most of the circuit!"]},{"cell_type":"code","execution_count":14,"metadata":{},"outputs":[],"source":["def create_attention_attr(\n"," clean_cache, clean_grad_cache\n",") -> TT[\"batch\", \"layer\", \"head_index\", \"dest\", \"src\"]:\n"," attention_stack = torch.stack(\n"," [clean_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," attention_grad_stack = torch.stack(\n"," [clean_grad_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," attention_attr = attention_grad_stack * attention_stack\n"," attention_attr = einops.rearrange(\n"," attention_attr,\n"," \"layer batch head_index dest src -> batch layer head_index dest src\",\n"," )\n"," return attention_attr\n","\n","\n","attention_attr = create_attention_attr(clean_cache, clean_grad_cache)"]},{"cell_type":"code","execution_count":15,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4']\n","['L0H0+', 'L0H0-', 'L0H1+', 'L0H1-', 'L0H2+']\n","['L0H0Q', 'L0H0K', 'L0H0V', 'L0H1Q', 'L0H1K']\n"]}],"source":["HEAD_NAMES = [\n"," f\"L{l}H{h}\" for l in 
range(model.cfg.n_layers) for h in range(model.cfg.n_heads)\n","]\n","HEAD_NAMES_SIGNED = [f\"{name}{sign}\" for name in HEAD_NAMES for sign in [\"+\", \"-\"]]\n","HEAD_NAMES_QKV = [\n"," f\"{name}{act_name}\" for name in HEAD_NAMES for act_name in [\"Q\", \"K\", \"V\"]\n","]\n","print(HEAD_NAMES[:5])\n","print(HEAD_NAMES_SIGNED[:5])\n","print(HEAD_NAMES_QKV[:5])"]},{"cell_type":"markdown","metadata":{},"source":[" An extremely janky way to plot the attention attribution patterns. We scale them to be in [-1, 1], split each head into a positive and negative part (so all of it is in [0, 1]), and then plot the top 20 head-halves (a head can appear twice!) by the max value of the attribution pattern."]},{"cell_type":"code","execution_count":16,"metadata":{},"outputs":[{"data":{"text/markdown":["### Attention Attribution for first sequence"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["### Summed Attention Attribution for all sequences"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\n"]}],"source":["def plot_attention_attr(attention_attr, tokens, top_k=20, index=0, title=\"\"):\n"," if len(tokens.shape) == 2:\n"," tokens = tokens[index]\n"," if len(attention_attr.shape) == 5:\n"," attention_attr = attention_attr[index]\n"," attention_attr_pos = attention_attr.clamp(min=-1e-5)\n"," attention_attr_neg = -attention_attr.clamp(max=1e-5)\n"," attention_attr_signed = torch.stack([attention_attr_pos, attention_attr_neg], dim=0)\n"," attention_attr_signed = einops.rearrange(\n"," attention_attr_signed,\n"," \"sign layer head_index dest src -> (layer head_index sign) dest src\",\n"," )\n"," attention_attr_signed = attention_attr_signed / attention_attr_signed.max()\n"," attention_attr_indices = (\n"," attention_attr_signed.max(-1).values.max(-1).values.argsort(descending=True)\n"," )\n"," # print(attention_attr_indices.shape)\n"," # print(attention_attr_indices)\n"," attention_attr_signed = attention_attr_signed[attention_attr_indices, :, :]\n"," head_labels = [HEAD_NAMES_SIGNED[i.item()] for i in attention_attr_indices]\n","\n"," if title:\n"," display(Markdown(\"### \" + title))\n"," display(\n"," pysvelte.AttentionMulti(\n"," tokens=model.to_str_tokens(tokens),\n"," attention=attention_attr_signed.permute(1, 2, 0)[:, :, :top_k],\n"," head_labels=head_labels[:top_k],\n"," )\n"," )\n","\n","\n","plot_attention_attr(\n"," attention_attr,\n"," clean_tokens,\n"," index=0,\n"," title=\"Attention Attribution for first sequence\",\n",")\n","\n","plot_attention_attr(\n"," attention_attr.sum(0),\n"," clean_tokens[0],\n"," title=\"Summed Attention Attribution for all sequences\",\n",")\n","print(\n"," \"Note: Plotted over first sequence for reference, but pairs have IO and S1 in different 
positions.\"\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Attribution Patching\n"," In the following sections, I will implement various kinds of attribution patching, and then compare them to the activation patching patterns (activation patching code copied from [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo))\n"," ### Residual Stream Patching\n","
Note: We add up across both d_model and batch (Explanation).\n"," We add up along d_model because we're taking the dot product - the derivative *is* the linear map that locally linearly approximates the metric, and so we take the dot product of our change vector with the derivative vector. Equivalently, we look at the effect of changing each coordinate independently, and then combine them by adding it up - it's linear, so this totally works.\n"," We add up across batch because we're taking the average of the metric, so each individual batch element provides `1/batch_size` of the overall effect. Because each batch element is independent of the others and no information moves between activations for different inputs, the batched version is equivalent to doing attribution patching separately for each input, and then averaging - in this second version the metric per input is *not* divided by batch_size because we don't average.
"]},{"cell_type":"code","execution_count":17,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_residual(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," clean_residual, residual_labels = clean_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=True\n"," )\n"," corrupted_residual = corrupted_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=False\n"," )\n"," corrupted_grad_residual = corrupted_grad_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=False\n"," )\n"," residual_attr = einops.reduce(\n"," corrupted_grad_residual * (clean_residual - corrupted_residual),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return residual_attr, residual_labels\n","\n","\n","residual_attr, residual_labels = attr_patch_residual(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," residual_attr,\n"," y=residual_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Residual Attribution Patching\",\n",")\n","\n","# ### Layer Output Patching"]},{"cell_type":"code","execution_count":18,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_layer_out(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," clean_layer_out, labels = clean_cache.decompose_resid(-1, return_labels=True)\n"," corrupted_layer_out = corrupted_cache.decompose_resid(-1, return_labels=False)\n"," corrupted_grad_layer_out = corrupted_grad_cache.decompose_resid(\n"," -1, return_labels=False\n"," )\n"," layer_out_attr = einops.reduce(\n"," corrupted_grad_layer_out * (clean_layer_out - corrupted_layer_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return layer_out_attr, labels\n","\n","\n","layer_out_attr, layer_out_labels = attr_patch_layer_out(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," layer_out_attr,\n"," y=layer_out_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Layer Output Attribution Patching\",\n",")"]},{"cell_type":"code","execution_count":19,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_head_out(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_out = clean_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_head_out = corrupted_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_grad_head_out = corrupted_grad_cache.stack_head_results(\n"," -1, return_labels=False\n"," )\n"," head_out_attr = einops.reduce(\n"," corrupted_grad_head_out * (clean_head_out - corrupted_head_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return head_out_attr, labels\n","\n","\n","head_out_attr, head_out_labels = attr_patch_head_out(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," head_out_attr,\n"," y=head_out_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Head Output Attribution Patching\",\n",")\n","sum_head_out_attr = einops.reduce(\n"," head_out_attr,\n"," \"(layer head) pos -> layer head\",\n"," \"sum\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n",")\n","imshow(\n"," sum_head_out_attr,\n"," yaxis=\"Layer\",\n"," xaxis=\"Head Index\",\n"," title=\"Head Output Attribution Patching Sum Over Pos\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ### Head Activation Patching\n"," Intuitively, a head has three inputs, keys, queries and values. We can patch each of these individually to get a sense for where the important part of each head's input comes from!\n"," As a sanity check, we also do this for the mixed value. 
The result is a linear map of this (`z @ W_O == result`), so this is the same as patching the output of the head.\n"," We plot both the patch for each head over each position, and summed over position (it tends to be pretty sparse, so the latter is the same)"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[{"data":{"text/markdown":["#### Key Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Query Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Mixed Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","\n","\n","def stack_head_vector_from_cache(\n"," cache, activation_name: Literal[\"q\", \"k\", \"v\", \"z\"]\n",") -> TT[\"layer_and_head_index\", \"batch\", \"pos\", \"d_head\"]:\n"," \"\"\"Stacks the head vectors from the cache from a specific activation (key, query, value or mixed_value (z)) into a single tensor.\"\"\"\n"," stacked_head_vectors = torch.stack(\n"," [cache[activation_name, l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," stacked_head_vectors = einops.rearrange(\n"," stacked_head_vectors,\n"," \"layer batch pos head_index d_head -> (layer head_index) batch pos d_head\",\n"," )\n"," return stacked_head_vectors\n","\n","\n","def attr_patch_head_vector(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"],\n",") -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_vector = stack_head_vector_from_cache(clean_cache, activation_name)\n"," corrupted_head_vector = stack_head_vector_from_cache(\n"," corrupted_cache, activation_name\n"," )\n"," corrupted_grad_head_vector = stack_head_vector_from_cache(\n"," corrupted_grad_cache, activation_name\n"," )\n"," head_vector_attr = einops.reduce(\n"," corrupted_grad_head_vector * (clean_head_vector - corrupted_head_vector),\n"," \"component batch pos d_head -> component pos\",\n"," \"sum\",\n"," )\n"," return head_vector_attr, labels\n","\n","\n","head_vector_attr_dict = {}\n","for activation_name, activation_name_full in [\n"," (\"k\", \"Key\"),\n"," (\"q\", \"Query\"),\n"," (\"v\", \"Value\"),\n"," (\"z\", \"Mixed Value\"),\n","]:\n"," display(Markdown(f\"#### {activation_name_full} Head Vector Attribution Patching\"))\n"," head_vector_attr_dict[activation_name], head_vector_labels = attr_patch_head_vector(\n"," clean_cache, corrupted_cache, 
corrupted_grad_cache, activation_name\n"," )\n"," imshow(\n"," head_vector_attr_dict[activation_name],\n"," y=head_vector_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=f\"{activation_name_full} Attribution Patching\",\n"," )\n"," sum_head_vector_attr = einops.reduce(\n"," head_vector_attr_dict[activation_name],\n"," \"(layer head) pos -> layer head\",\n"," \"sum\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," )\n"," imshow(\n"," sum_head_vector_attr,\n"," yaxis=\"Layer\",\n"," xaxis=\"Head Index\",\n"," title=f\"{activation_name_full} Attribution Patching Sum Over Pos\",\n"," )"]},{"cell_type":"code","execution_count":21,"metadata":{},"outputs":[{"data":{"text/markdown":["### Head Pattern Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","\n","\n","def stack_head_pattern_from_cache(\n"," cache,\n",") -> TT[\"layer_and_head_index\", \"batch\", \"dest_pos\", \"src_pos\"]:\n"," \"\"\"Stacks the head patterns from the cache into a single tensor.\"\"\"\n"," stacked_head_pattern = torch.stack(\n"," [cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," stacked_head_pattern = einops.rearrange(\n"," stacked_head_pattern,\n"," \"layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos\",\n"," )\n"," return stacked_head_pattern\n","\n","\n","def attr_patch_head_pattern(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"dest_pos\", \"src_pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_pattern = stack_head_pattern_from_cache(clean_cache)\n"," corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)\n"," corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)\n"," head_pattern_attr = einops.reduce(\n"," corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern),\n"," \"component batch dest_pos src_pos -> component dest_pos src_pos\",\n"," \"sum\",\n"," )\n"," return head_pattern_attr, labels\n","\n","\n","head_pattern_attr, labels = attr_patch_head_pattern(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","\n","plot_attention_attr(\n"," einops.rearrange(\n"," head_pattern_attr,\n"," \"(layer head) dest src -> layer head dest src\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," ),\n"," clean_tokens,\n"," index=0,\n"," title=\"Head Pattern Attribution Patching\",\n",")"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_head_vector_grad_input_from_grad_cache(\n"," grad_cache: ActivationCache, activation_name: Literal[\"q\", \"k\", \"v\"], layer: int\n",") -> TT[\"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," vector_grad = grad_cache[activation_name, layer]\n"," ln_scales = grad_cache[\"scale\", layer, \"ln1\"]\n"," attn_layer_object = model.blocks[layer].attn\n"," if activation_name == \"q\":\n"," W = attn_layer_object.W_Q\n"," elif activation_name == \"k\":\n"," W = attn_layer_object.W_K\n"," elif activation_name == \"v\":\n"," W = attn_layer_object.W_V\n"," else:\n"," raise ValueError(\"Invalid activation name\")\n","\n"," return einsum(\n"," \"batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model\",\n"," vector_grad,\n"," ln_scales.squeeze(-1),\n"," W,\n"," )\n","\n","\n","def get_stacked_head_vector_grad_input(\n"," grad_cache, activation_name: Literal[\"q\", \"k\", \"v\"]\n",") -> TT[\"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack(\n"," [\n"," get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, l)\n"," for l in range(model.cfg.n_layers)\n"," ],\n"," dim=0,\n"," )\n","\n","\n","def get_full_vector_grad_input(\n"," grad_cache,\n",") -> TT[\"qkv\", \"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack(\n"," [\n"," get_stacked_head_vector_grad_input(grad_cache, activation_name)\n"," for activation_name in [\"q\", \"k\", \"v\"]\n"," ],\n"," dim=0,\n"," )\n","\n","\n","def attr_patch_head_path(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"qkv\", \"dest_component\", \"src_component\", \"pos\"]:\n"," \"\"\"\n"," Computes the attribution patch along the path between each pair of heads.\n","\n"," Sets this to zero for the path from any late head to any early head\n","\n"," \"\"\"\n"," 
start_labels = HEAD_NAMES\n"," end_labels = HEAD_NAMES_QKV\n"," full_vector_grad_input = get_full_vector_grad_input(corrupted_grad_cache)\n"," clean_head_result_stack = clean_cache.stack_head_results(-1)\n"," corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)\n"," diff_head_result = einops.rearrange(\n"," clean_head_result_stack - corrupted_head_result_stack,\n"," \"(layer head_index) batch pos d_model -> layer batch pos head_index d_model\",\n"," layer=model.cfg.n_layers,\n"," head_index=model.cfg.n_heads,\n"," )\n"," path_attr = einsum(\n"," \"qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos\",\n"," full_vector_grad_input,\n"," diff_head_result,\n"," )\n"," correct_layer_order_mask = (\n"," torch.arange(model.cfg.n_layers)[None, :, None, None, None, None]\n"," > torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]\n"," ).to(path_attr.device)\n"," zero = torch.zeros(1, device=path_attr.device)\n"," path_attr = torch.where(correct_layer_order_mask, path_attr, zero)\n","\n"," path_attr = einops.rearrange(\n"," path_attr,\n"," \"qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos\",\n"," )\n"," return path_attr, end_labels, start_labels\n","\n","\n","head_path_attr, end_labels, start_labels = attr_patch_head_path(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," head_path_attr.sum(-1),\n"," y=end_labels,\n"," yaxis=\"Path End (Head Input)\",\n"," x=start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=\"Head Path Attribution Patching\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" This is hard to parse. Here's an experiment with filtering for the most important heads and showing their paths."]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)\n","line(head_out_values)\n","top_head_indices = head_out_indices[:22].sort().values\n","top_end_indices = []\n","top_end_labels = []\n","top_start_indices = []\n","top_start_labels = []\n","for i in top_head_indices:\n"," i = i.item()\n"," top_start_indices.append(i)\n"," top_start_labels.append(start_labels[i])\n"," for j in range(3):\n"," top_end_indices.append(3 * i + j)\n"," top_end_labels.append(end_labels[3 * i + j])\n","\n","imshow(\n"," head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n"," y=top_end_labels,\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=\"Head Path Attribution Patching (Filtered for Top Heads)\",\n",")"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["for j, composition_type in enumerate([\"Query\", \"Key\", \"Value\"]):\n"," imshow(\n"," head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1),\n"," y=top_end_labels[j::3],\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=f\"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)\",\n"," )"]},{"cell_type":"code","execution_count":25,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["top_head_path_attr = einops.rearrange(\n"," head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n"," \"(head_end qkv) head_start -> qkv head_end head_start\",\n"," qkv=3,\n",")\n","imshow(\n"," top_head_path_attr,\n"," y=[i[:-1] for i in top_end_labels[::3]],\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=f\"Head Path Attribution Patching (Filtered for Top Heads)\",\n"," facet_col=0,\n"," facet_labels=[\"Query\", \"Key\", \"Value\"],\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - Q input means )"]},{"cell_type":"code","execution_count":26,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["interesting_heads = [\n"," 5 * model.cfg.n_heads + 5,\n"," 8 * model.cfg.n_heads + 6,\n"," 9 * model.cfg.n_heads + 9,\n","]\n","interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]\n","for head_index, label in zip(interesting_heads, interesting_head_labels):\n"," in_paths = head_path_attr[3 * head_index : 3 * head_index + 3].sum(-1)\n"," out_paths = head_path_attr[:, head_index].sum(-1)\n"," out_paths = einops.rearrange(out_paths, \"(layer_head qkv) -> qkv layer_head\", qkv=3)\n"," all_paths = torch.cat([in_paths, out_paths], dim=0)\n"," all_paths = einops.rearrange(\n"," all_paths,\n"," \"path_type (layer head) -> path_type layer head\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," )\n"," imshow(\n"," all_paths,\n"," facet_col=0,\n"," facet_labels=[\n"," \"Query (In)\",\n"," \"Key (In)\",\n"," \"Value (In)\",\n"," \"Query (Out)\",\n"," \"Key (Out)\",\n"," \"Value (Out)\",\n"," ],\n"," title=f\"Input and Output Paths for head {label}\",\n"," yaxis=\"Layer\",\n"," xaxis=\"Head\",\n"," )"]},{"cell_type":"markdown","metadata":{},"source":[" ## Validating Attribution vs Activation Patching\n"," Let's now compare attribution and activation patching. Generally it's a decent approximation! The main place it fails is MLP0 and the residual stream\n"," My fuzzy intuition is that attribution patching works badly for \"big\" things which are poorly modelled as linear approximations, and works well for \"small\" things which are more like incremental changes. 
Anything involving replacing the embedding is a \"big\" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an \"extended embedding\" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.\n"," See more discussion in the accompanying blog post!\n"]},{"cell_type":"markdown","metadata":{},"source":[" First do some refactoring to make attribution patching more generic. We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!"]},{"cell_type":"code","execution_count":27,"metadata":{},"outputs":[],"source":["attribution_cache_dict = {}\n","for key in corrupted_grad_cache.cache_dict.keys():\n"," attribution_cache_dict[key] = corrupted_grad_cache.cache_dict[key] * (\n"," clean_cache.cache_dict[key] - corrupted_cache.cache_dict[key]\n"," )\n","attr_cache = ActivationCache(attribution_cache_dict, model)"]},{"cell_type":"markdown","metadata":{},"source":[" By block: For each head we patch the starting residual stream, attention output + MLP output"]},{"cell_type":"code","execution_count":28,"metadata":{},"outputs":[],"source":["str_tokens = model.to_str_tokens(clean_tokens[0])\n","context_length = len(str_tokens)"]},{"cell_type":"code","execution_count":29,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"95a5290e11b64b6a95ef5dd37d027c7a","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_block_act_patch_result = patching.get_act_patch_block_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","imshow(\n"," every_block_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Activation Patching Per Block\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",")"]},{"cell_type":"code","execution_count":30,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_block_every(attr_cache):\n"," resid_pre_attr = einops.reduce(\n"," attr_cache.stack_activation(\"resid_pre\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," attn_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"attn_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," mlp_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"mlp_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n","\n"," every_block_attr_patch_result = torch.stack(\n"," [resid_pre_attr, attn_out_attr, mlp_out_attr], dim=0\n"," )\n"," return every_block_attr_patch_result\n","\n","\n","every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)\n","imshow(\n"," every_block_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Attribution Patching Per Block\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",")"]},{"cell_type":"code","execution_count":31,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_block_attr_patch_result.reshape(3, -1),\n"," x=every_block_act_patch_result.reshape(3, -1),\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Attribution vs Activation Patching Per Block\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," hover=[\n"," f\"Layer {l}, Position {p}, |{str_tokens[p]}|\"\n"," for l in range(model.cfg.n_layers)\n"," for p in range(context_length)\n"," ],\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers), \"layer -> (layer pos)\", pos=context_length\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow."]},{"cell_type":"code","execution_count":32,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"18b2e6b0985b40cd8c0cd1a16ba62975","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","imshow(\n"," every_head_all_pos_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Activation Patching Per Head (All Pos)\",\n"," xaxis=\"Head\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n",")"]},{"cell_type":"code","execution_count":33,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_all_pos_every(attr_cache):\n"," head_out_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_q_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_k_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_v_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack(\n"," [\n"," head_out_all_pos_attr,\n"," head_q_all_pos_attr,\n"," head_k_all_pos_attr,\n"," head_v_all_pos_attr,\n"," head_pattern_all_pos_attr,\n"," ]\n"," )\n","\n","\n","every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(\n"," attr_cache\n",")\n","imshow(\n"," every_head_all_pos_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution Patching Per Head (All Pos)\",\n"," xaxis=\"Head\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n",")"]},{"cell_type":"code","execution_count":34,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_head_all_pos_attr_patch_result.reshape(5, -1),\n"," x=every_head_all_pos_act_patch_result.reshape(5, -1),\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution vs Activation Patching Per Head (All Pos)\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," include_diag=True,\n"," hover=head_out_labels,\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers),\n"," \"layer -> (layer head)\",\n"," head=model.cfg.n_heads,\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" We see pretty good results in general, but significant errors for heads L5H5 on query and moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. But each of these is fine for pattern and output. My guess is that the problem is that these have pretty saturated attention on a single token, and the linear approximation is thus not great on the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!\n"," Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has indirect object and subject in the opposite order (IO first). But otherwise everything lines up between the prompts"]},{"cell_type":"code","execution_count":35,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["graph_tok_labels = [\n"," f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))\n","]\n","imshow(\n"," clean_cache[\"pattern\", 5][:, 5],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L5H5\",\n"," facet_name=\"Prompt\",\n",")\n","imshow(\n"," clean_cache[\"pattern\", 10][:, 7],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L10H7\",\n"," facet_name=\"Prompt\",\n",")\n","imshow(\n"," clean_cache[\"pattern\", 11][:, 10],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L11H10\",\n"," facet_name=\"Prompt\",\n",")\n","\n","\n","# [markdown]"]},{"cell_type":"code","execution_count":36,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"06f39489001845849fbc7446a07066f4","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","every_head_by_pos_act_patch_result = einops.rearrange(\n"," every_head_by_pos_act_patch_result,\n"," \"act_type layer pos head -> act_type (layer head) pos\",\n",")\n","imshow(\n"," every_head_by_pos_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Activation Patching Per Head (By Pos)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer & Head\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=head_out_labels,\n",")"]},{"cell_type":"code","execution_count":37,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_by_pos_every(attr_cache):\n"," head_out_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_q_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_k_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_v_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer dest_pos head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack(\n"," [\n"," head_out_by_pos_attr,\n"," head_q_by_pos_attr,\n"," head_k_by_pos_attr,\n"," head_v_by_pos_attr,\n"," head_pattern_by_pos_attr,\n"," ]\n"," )\n","\n","\n","every_head_by_pos_attr_patch_result = get_attr_patch_attn_head_by_pos_every(attr_cache)\n","every_head_by_pos_attr_patch_result = einops.rearrange(\n"," every_head_by_pos_attr_patch_result,\n"," \"act_type layer pos head -> act_type (layer head) pos\",\n",")\n","imshow(\n"," every_head_by_pos_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution Patching Per Head (By Pos)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer & Head\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=head_out_labels,\n",")"]},{"cell_type":"code","execution_count":38,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_head_by_pos_attr_patch_result.reshape(5, -1),\n"," x=every_head_by_pos_act_patch_result.reshape(5, -1),\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution vs Activation Patching Per Head (by Pos)\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," include_diag=True,\n"," hover=[f\"{label} {tok}\" for label in head_out_labels for tok in graph_tok_labels],\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers),\n"," \"layer -> (layer head pos)\",\n"," head=model.cfg.n_heads,\n"," pos=15,\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Factual Knowledge Patching Example\n"," Incomplete, but maybe of interest!\n"," Note that I have better results with the corrupted prompt as having random words rather than Colosseum."]},{"cell_type":"code","execution_count":39,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-xl into HookedTransformer\n","Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Paris']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.73 Prob: 95.80% Token: | Paris|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.73\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m95.80\u001b[0m\u001b[1m% Token: | Paris|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n","Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n","Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n","Top 3th token. Logit: 14.58 Prob: 0.21% Token: | É|\n","Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n","Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n","Top 6th token. Logit: 13.77 Prob: 0.09% Token: | Le|\n","Top 7th token. Logit: 13.66 Prob: 0.08% Token: | Ang|\n","Top 8th token. Logit: 13.43 Prob: 0.06% Token: | V|\n","Top 9th token. Logit: 13.42 Prob: 0.06% Token: | Stras|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Paris', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Paris'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Tokenized prompt: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Rome']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.02 Prob: 83.70% Token: | Rome|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.02\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m83.70\u001b[0m\u001b[1m% Token: | Rome|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.02 Prob: 83.70% Token: | Rome|\n","Top 1th token. Logit: 17.03 Prob: 4.23% Token: | Naples|\n","Top 2th token. Logit: 16.85 Prob: 3.51% Token: | Pompe|\n","Top 3th token. Logit: 16.14 Prob: 1.73% Token: | Ver|\n","Top 4th token. Logit: 15.87 Prob: 1.32% Token: | Florence|\n","Top 5th token. Logit: 14.77 Prob: 0.44% Token: | Roma|\n","Top 6th token. Logit: 14.68 Prob: 0.40% Token: | Milan|\n","Top 7th token. Logit: 14.66 Prob: 0.39% Token: | ancient|\n","Top 8th token. Logit: 14.37 Prob: 0.29% Token: | Pal|\n","Top 9th token. Logit: 14.30 Prob: 0.27% Token: | Constantinople|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Rome', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Rome'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"}],"source":["gpt2_xl = HookedTransformer.from_pretrained(\"gpt2-xl\")\n","clean_prompt = \"The Eiffel Tower is located in the city of\"\n","clean_answer = \" Paris\"\n","# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n","corrupted_prompt = \"The Colosseum is located in the city of\"\n","corrupted_answer = \" Rome\"\n","utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)\n","utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)"]},{"cell_type":"code","execution_count":40,"metadata":{},"outputs":[],"source":["clean_answer_index = gpt2_xl.to_single_token(clean_answer)\n","corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)\n","\n","\n","def factual_logit_diff(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return logits[0, -1, clean_answer_index] - logits[0, -1, corrupted_answer_index]"]},{"cell_type":"code","execution_count":41,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 10.634519577026367\n","Corrupted logit diff: -8.988396644592285\n","Clean Metric: tensor(1., device='cuda:0', grad_fn=)\n","Corrupted Metric: tensor(0., device='cuda:0', grad_fn=)\n"]}],"source":["clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)\n","CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()\n","corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)\n","CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()\n","\n","\n","def factual_metric(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return (factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL) / (\n"," CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL\n"," )\n","\n","\n","print(\"Clean logit diff:\", 
CLEAN_LOGIT_DIFF_FACTUAL)\n","print(\"Corrupted logit diff:\", CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean Metric:\", factual_metric(clean_logits))\n","print(\"Corrupted Metric:\", factual_metric(corrupted_logits))"]},{"cell_type":"code","execution_count":42,"metadata":{},"outputs":[],"source":["# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)"]},{"cell_type":"code","execution_count":43,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Corrupted: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n"]}],"source":["clean_tokens = gpt2_xl.to_tokens(clean_prompt)\n","clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)\n","corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)\n","corrupted_str_tokens = gpt2_xl.to_str_tokens(corrupted_prompt)\n","print(\"Clean:\", clean_str_tokens)\n","print(\"Corrupted:\", corrupted_str_tokens)"]},{"cell_type":"code","execution_count":44,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"b767eef7a3cd49b9b3cb6e5301463f08","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/48 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def act_patch_residual(clean_cache, corrupted_tokens, model: HookedTransformer, metric):\n"," if len(corrupted_tokens.shape) == 2:\n"," corrupted_tokens = corrupted_tokens[0]\n"," residual_patches = torch.zeros(\n"," (model.cfg.n_layers, len(corrupted_tokens)), device=model.cfg.device\n"," )\n","\n"," def residual_hook(resid_pre, hook, layer, pos):\n"," resid_pre[:, pos, :] = clean_cache[\"resid_pre\", layer][:, pos, :]\n"," return resid_pre\n","\n"," for layer in tqdm.tqdm(range(model.cfg.n_layers)):\n"," for pos in range(len(corrupted_tokens)):\n"," patched_logits = model.run_with_hooks(\n"," corrupted_tokens,\n"," fwd_hooks=[\n"," (\n"," f\"blocks.{layer}.hook_resid_pre\",\n"," partial(residual_hook, layer=layer, pos=pos),\n"," )\n"," ],\n"," )\n"," residual_patches[layer, pos] = metric(patched_logits).item()\n"," return residual_patches\n","\n","\n","residual_act_patch = act_patch_residual(\n"," clean_cache, corrupted_tokens, gpt2_xl, factual_metric\n",")\n","\n","imshow(\n"," residual_act_patch,\n"," title=\"Factual Recall Patching (Residual)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," x=clean_str_tokens,\n",")"]}],"metadata":{"kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.11.8"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"}}},"nbformat":4,"nbformat_minor":2} diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb index 581a6365d..791207f47 100644 --- a/demos/BERT.ipynb +++ b/demos/BERT.ipynb @@ -29,45 +29,70 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Running as a Jupyter 
notebook - intended for development only!\n" + "Running as a Jupyter notebook - intended for development only!\n", + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_39188/4022418010.py:26: DeprecationWarning:\n", + "\n", + "`magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + "\n", + "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_39188/4022418010.py:27: DeprecationWarning:\n", + "\n", + "`magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + "\n" ] } ], "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "import os\n", + "\n", "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", "DEVELOPMENT_MODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", "try:\n", " import google.colab\n", + "\n", " IN_COLAB = True\n", " print(\"Running as a Colab notebook\")\n", - " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n", - " %pip install circuitsvis\n", - " \n", + "\n", " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", "except:\n", " IN_COLAB = False\n", + "\n", + "if not IN_GITHUB and not IN_COLAB:\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", " from IPython import get_ipython\n", "\n", " ipython = get_ipython()\n", " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", " ipython.magic(\"load_ext 
autoreload\")\n", - " ipython.magic(\"autoreload 2\")" + " ipython.magic(\"autoreload 2\")\n", + "\n", + "if IN_COLAB:\n", + " %pip install transformer_lens\n", + " %pip install circuitsvis" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -81,6 +106,7 @@ "source": [ "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", "import plotly.io as pio\n", + "\n", "if IN_COLAB or not DEVELOPMENT_MODE:\n", " pio.renderers.default = \"colab\"\n", "else:\n", @@ -90,40 +116,41 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import circuitsvis as cv\n", + "\n", "# Testing that the library works\n", "cv.examples.hello(\"Neel\")" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -137,16 +164,16 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -167,26 +194,28 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "WARNING:root:HookedEncoder is still in beta. Please be aware that model preprocessing (e.g. LayerNorm folding) is not yet supported and backward compatibility is not guaranteed.\n" + "WARNING:root:Support for BERT in TransformerLens is currently experimental, until such a time when it has feature parity with HookedTransformer and has been tested on real research tasks. Until then, backward compatibility is not guaranteed. Please see the docs for information on the limitations of the current implementation.\n", + "If using BERT for interpretability research, keep in mind that BERT has some significant architectural differences to GPT. 
For example, LayerNorms are applied *after* the attention and MLP components, meaning that the last LayerNorm in a block cannot be folded.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Moving model to device: cpu\n", + "Moving model to device: mps\n", "Loaded pretrained model bert-base-cased into HookedTransformer\n" ] } ], "source": [ + "# NBVAL_IGNORE_OUTPUT\n", "bert = HookedEncoder.from_pretrained(\"bert-base-cased\")\n", "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")" ] @@ -201,7 +230,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -213,7 +242,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -230,7 +259,7 @@ "prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n", "\n", "print(f\"Prompt: {prompt}\")\n", - "print(f\"Prediction: \\\"{prediction}\\\"\")" + "print(f'Prediction: \"{prediction}\"')" ] }, { @@ -258,7 +287,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.11.8" }, "orig_nbformat": 4 }, diff --git a/demos/Grokking_Demo.ipynb b/demos/Grokking_Demo.ipynb index 7e3792095..473d7ca82 100644 --- a/demos/Grokking_Demo.ipynb +++ b/demos/Grokking_Demo.ipynb @@ -53,13 +53,14 @@ ], "source": [ "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", + "\n", "DEVELOPMENT_MODE = True\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", "try:\n", " import google.colab\n", " IN_COLAB = True\n", " print(\"Running as a Colab notebook\")\n", - " %pip install transformer-lens\n", - " %pip install circuitsvis\n", " \n", " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", @@ -73,7 +74,11 @@ " ipython = get_ipython()\n", " # Code to 
automatically update the HookedTransformer code as its edited without restarting the kernel\n", " ipython.magic(\"load_ext autoreload\")\n", - " ipython.magic(\"autoreload 2\")" + " ipython.magic(\"autoreload 2\")\n", + " \n", + "if IN_COLAB or IN_GITHUB:\n", + " %pip install transformer_lens\n", + " %pip install circuitsvis" ] }, { @@ -154,7 +159,10 @@ " HookedRootModule,\n", " HookPoint,\n", ") # Hooking utilities\n", - "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache" + "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n", + "\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" ] }, { @@ -281,7 +289,7 @@ } ], "source": [ - "dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).cuda()\n", + "dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).to(device)\n", "print(dataset[:5])\n", "print(dataset.shape)" ] @@ -386,7 +394,7 @@ " d_vocab_out=p,\n", " n_ctx=3,\n", " init_weights=True,\n", - " device=\"cuda\",\n", + " device=device,\n", " seed = 999,\n", ")" ] @@ -1645,7 +1653,7 @@ " fourier_basis_names.append(f\"Sin {freq}\")\n", " fourier_basis.append(torch.cos(torch.arange(p)*2 * torch.pi * freq / p))\n", " fourier_basis_names.append(f\"Cos {freq}\")\n", - "fourier_basis = torch.stack(fourier_basis, dim=0).cuda()\n", + "fourier_basis = torch.stack(fourier_basis, dim=0).to(device)\n", "fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)\n", "imshow(fourier_basis, xaxis=\"Input\", yaxis=\"Component\", y=fourier_basis_names)" ] @@ -2394,7 +2402,7 @@ } ], "source": [ - "neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).cuda()\n", + "neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).to(device)\n", "for freq in range(0, p//2):\n", " for x in [0, 2*(freq+1) - 1, 2*(freq+1)]:\n", " for y in [0, 2*(freq+1) - 1, 2*(freq+1)]:\n", @@ -2993,7 +3001,7 @@ " a = torch.arange(p)[:, None, 
None]\n", " b = torch.arange(p)[None, :, None]\n", " c = torch.arange(p)[None, None, :]\n", - " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).cuda()\n", + " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).to(device)\n", " cube_predicted_logits /= cube_predicted_logits.norm()\n", " coses[freq] = cube_predicted_logits" ] @@ -3124,7 +3132,7 @@ " a = torch.arange(p)[:, None, None]\n", " b = torch.arange(p)[None, :, None]\n", " c = torch.arange(p)[None, None, :]\n", - " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).cuda()\n", + " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).to(device)\n", " cube_predicted_logits /= cube_predicted_logits.norm()\n", " cos_cube.append(cube_predicted_logits)\n", "cos_cube = torch.stack(cos_cube, dim=0)\n", @@ -3486,11 +3494,11 @@ "a = torch.arange(p)[:, None]\n", "b = torch.arange(p)[None, :]\n", "for freq in key_freqs:\n", - " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " cos_apb_vec /= cos_apb_vec.norm()\n", " cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n", - " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " sin_apb_vec /= sin_apb_vec.norm()\n", " sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n", @@ -3555,11 +3563,11 @@ " a = torch.arange(p)[:, None]\n", " b = torch.arange(p)[None, :]\n", " for freq in key_freqs:\n", - " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " cos_apb_vec /= cos_apb_vec.norm()\n", " cos_apb_vec = 
einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n", - " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " sin_apb_vec /= sin_apb_vec.norm()\n", " sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n", @@ -3718,11 +3726,11 @@ "a = torch.arange(p)[:, None]\n", "b = torch.arange(p)[None, :]\n", "for freq in key_freqs:\n", - " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " cos_apb_vec /= cos_apb_vec.norm()\n", " cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n", - " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " sin_apb_vec /= sin_apb_vec.norm()\n", " sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n", @@ -3765,11 +3773,11 @@ " a = torch.arange(p)[:, None]\n", " b = torch.arange(p)[None, :]\n", " for freq in key_freqs:\n", - " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " cos_apb_vec /= cos_apb_vec.norm()\n", " cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n", - " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " sin_apb_vec /= sin_apb_vec.norm()\n", " sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a 
b) 1\")\n", " approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n", diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb index c871a6bd8..b2f89b695 100644 --- a/demos/Main_Demo.ipynb +++ b/demos/Main_Demo.ipynb @@ -74,8 +74,7 @@ " ip.extension_manager.load('autoreload')\n", " %autoreload 2\n", " \n", - "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", - "IN_GITHUB = True\n" + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n" ] }, { diff --git a/makefile b/makefile index b786aa209..17d583dae 100644 --- a/makefile +++ b/makefile @@ -18,8 +18,9 @@ docstring-test: poetry run pytest transformer_lens/ notebook-test: - poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/BERT.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Exploratory_Analysis_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb test: make unit-test From 07e8f386c49a6fc60c9a7785aaf3c5dcc3018e26 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Sat, 27 Apr 2024 20:00:45 +0200 Subject: [PATCH 66/73] removed import --- transformer_lens/components/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_lens/components/attention.py b/transformer_lens/components/attention.py index 9654b6f09..4b81d573c 100644 --- a/transformer_lens/components/attention.py +++ b/transformer_lens/components/attention.py @@ -2,7 +2,7 @@ This module contains all the component :class:`Attention`. 
""" -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Union import torch import torch.nn as nn From 517fab55b86f2e5d8c55d904e8cdc5069ddc4aa1 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Sat, 27 Apr 2024 20:07:19 +0200 Subject: [PATCH 67/73] cleaned imports --- transformer_lens/components/abstract_attention.py | 5 +---- transformer_lens/components/attention.py | 1 - transformer_lens/components/bert_block.py | 4 +--- transformer_lens/components/moe.py | 2 +- transformer_lens/components/transformer_block.py | 6 ++---- 5 files changed, 5 insertions(+), 13 deletions(-) diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 245f8c2e8..076f2a063 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -15,10 +15,7 @@ from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry -from transformer_lens.utils import ( - get_offset_position_ids, -) - +from transformer_lens.utils import get_offset_position_ids if is_bitsandbytes_available(): import bitsandbytes as bnb diff --git a/transformer_lens/components/attention.py b/transformer_lens/components/attention.py index 4b81d573c..c5363fd86 100644 --- a/transformer_lens/components/attention.py +++ b/transformer_lens/components/attention.py @@ -11,7 +11,6 @@ from transformer_lens.components import AbstractAttention from transformer_lens.HookedTransformerConfig import HookedTransformerConfig - if is_bitsandbytes_available(): from bitsandbytes.nn.modules import Params4bit diff --git a/transformer_lens/components/bert_block.py b/transformer_lens/components/bert_block.py index aa01e21e7..3740d914b 100644 --- a/transformer_lens/components/bert_block.py +++ b/transformer_lens/components/bert_block.py @@ -7,13 +7,11 
@@ import torch import torch.nn as nn from jaxtyping import Float -from transformer_lens.utils import ( - repeat_along_head_dimension, -) from transformer_lens.components import MLP, Attention, LayerNorm from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.utils import repeat_along_head_dimension class BertBlock(nn.Module): diff --git a/transformer_lens/components/moe.py b/transformer_lens/components/moe.py index 089f3eab8..60598dc46 100644 --- a/transformer_lens/components/moe.py +++ b/transformer_lens/components/moe.py @@ -6,7 +6,7 @@ from fancy_einsum import einsum from jaxtyping import Float -from transformer_lens.components import GatedMLP, MLP +from transformer_lens.components import MLP, GatedMLP from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig diff --git a/transformer_lens/components/transformer_block.py b/transformer_lens/components/transformer_block.py index b8281c63b..8980f9e8c 100644 --- a/transformer_lens/components/transformer_block.py +++ b/transformer_lens/components/transformer_block.py @@ -8,24 +8,22 @@ import torch import torch.nn as nn from jaxtyping import Float, Int -from transformer_lens.utils import ( - repeat_along_head_dimension, -) from transformer_lens.components import ( MLP, - MoE, Attention, GatedMLP, GroupedQueryAttention, LayerNorm, LayerNormPre, + MoE, RMSNorm, RMSNormPre, ) from transformer_lens.hook_points import HookPoint from transformer_lens.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry +from transformer_lens.utils import repeat_along_head_dimension # Transformer Block From 4053079442712e9a9ab1c644ff5c561ef8b3f8cb Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Sat, 27 Apr 2024 20:11:09 +0200 Subject: [PATCH 68/73] ran black --- 
transformer_lens/components/abstract_attention.py | 15 ++++++++++----- transformer_lens/components/attention.py | 1 + transformer_lens/components/gated_mlp.py | 2 +- .../components/grouped_query_attention.py | 3 +-- transformer_lens/components/mlp.py | 2 +- transformer_lens/components/moe.py | 2 +- 6 files changed, 15 insertions(+), 10 deletions(-) diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 076f2a063..666bf169d 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -21,6 +21,7 @@ import bitsandbytes as bnb from bitsandbytes.nn.modules import Params4bit + class AbstractAttention(ABC, nn.Module): alibi: Union[torch.Tensor, None] @@ -57,8 +58,8 @@ def __init__( self.W_O = nn.Parameter( torch.empty(self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model, dtype=cfg.dtype) ) - self.W_K: Params4bit|nn.Parameter = abstract_attribute() - self.W_V: Params4bit|nn.Parameter = abstract_attribute() + self.W_K: Params4bit | nn.Parameter = abstract_attribute() + self.W_V: Params4bit | nn.Parameter = abstract_attribute() self.b_Q = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)) self.b_K: nn.Parameter = abstract_attribute() @@ -206,7 +207,9 @@ def forward( # only recompute when necessary to increase efficiency. if self.alibi is None or key_ctx > self.alibi.size(-1): - self.alibi = AbstractAttention.create_alibi_bias(self.cfg.n_heads, key_ctx, self.cfg.device) + self.alibi = AbstractAttention.create_alibi_bias( + self.cfg.n_heads, key_ctx, self.cfg.device + ) attn_scores += self.alibi[ :, :query_ctx, :key_ctx @@ -636,7 +639,9 @@ def create_alibi_bias( The ALiBi bias that should be added to the attention scores before the softmax. 
""" # Create the slope matrix - slope: Float[torch.Tensor, "query key"] = AbstractAttention.create_alibi_slope(n_ctx, device) + slope: Float[torch.Tensor, "query key"] = AbstractAttention.create_alibi_slope( + n_ctx, device + ) # Create the scalar multiplier for each head. multipliers: Float[torch.Tensor, "head_idx"] = AbstractAttention.create_alibi_multipliers( @@ -646,4 +651,4 @@ def create_alibi_bias( # The ALiBi bias is then m * slope_matrix alibi_bias = torch.einsum("ij,k->kij", slope, multipliers) - return alibi_bias \ No newline at end of file + return alibi_bias diff --git a/transformer_lens/components/attention.py b/transformer_lens/components/attention.py index c5363fd86..c463361c5 100644 --- a/transformer_lens/components/attention.py +++ b/transformer_lens/components/attention.py @@ -14,6 +14,7 @@ if is_bitsandbytes_available(): from bitsandbytes.nn.modules import Params4bit + # Attention class Attention(AbstractAttention): def __init__( diff --git a/transformer_lens/components/gated_mlp.py b/transformer_lens/components/gated_mlp.py index 49443cdf5..886108d74 100644 --- a/transformer_lens/components/gated_mlp.py +++ b/transformer_lens/components/gated_mlp.py @@ -140,4 +140,4 @@ def forward( self.W_out, ) + self.b_out - ) \ No newline at end of file + ) diff --git a/transformer_lens/components/grouped_query_attention.py b/transformer_lens/components/grouped_query_attention.py index d4fe95276..6c94e00a2 100644 --- a/transformer_lens/components/grouped_query_attention.py +++ b/transformer_lens/components/grouped_query_attention.py @@ -1,4 +1,3 @@ - from typing import Dict, Tuple, Union import torch @@ -188,4 +187,4 @@ def calculate_z_scores( Float[torch.Tensor, "batch head_index query_pos key_pos"]: The z scores. 
""" v = torch.repeat_interleave(v, dim=2, repeats=self.repeat_kv_heads) - return super().calculate_z_scores(v, pattern) \ No newline at end of file + return super().calculate_z_scores(v, pattern) diff --git a/transformer_lens/components/mlp.py b/transformer_lens/components/mlp.py index 91dbbd95e..c9a18fc3f 100644 --- a/transformer_lens/components/mlp.py +++ b/transformer_lens/components/mlp.py @@ -76,4 +76,4 @@ def forward( self.W_out, ) + self.b_out - ) \ No newline at end of file + ) diff --git a/transformer_lens/components/moe.py b/transformer_lens/components/moe.py index 60598dc46..01f0298c7 100644 --- a/transformer_lens/components/moe.py +++ b/transformer_lens/components/moe.py @@ -59,4 +59,4 @@ def forward( # accumulate the weighted outputs from the expert results[batch] += weights[batch, pos, expert, None, None] * expert_mlp(x[batch]) - return results \ No newline at end of file + return results From 2d1ee77b7ef55fbb95d7b33480601f4ff1e830b6 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Sat, 27 Apr 2024 20:18:38 +0200 Subject: [PATCH 69/73] updated doc string --- transformer_lens/components/abstract_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 666bf169d..f18c3abcf 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -535,12 +535,12 @@ def create_alibi_slope( Examples: - >>> Attention.create_alibi_slope(3) + >>> AbstractAttention.create_alibi_slope(3) tensor([[ 0., 0., 0.], [-1., 0., 0.], [-2., -1., 0.]]) - >>> Attention.create_alibi_slope(4) + >>> AbstractAttention.create_alibi_slope(4) tensor([[ 0., 0., 0., 0.], [-1., 0., 0., 0.], [-2., -1., 0., 0.], From 38b4ddab954f15c12237546d9d199632b3714f57 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Sat, 27 Apr 2024 20:53:46 +0200 Subject: [PATCH 70/73] finished fixing docstring --- 
transformer_lens/components/abstract_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index f18c3abcf..2de38518f 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -579,10 +579,10 @@ def create_alibi_multipliers( Examples: - >>> Attention.create_alibi_multipliers(8) + >>> AbstractAttention.create_alibi_multipliers(8) tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039]) - >>> Attention.create_alibi_multipliers(16) + >>> AbstractAttention.create_alibi_multipliers(16) tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312, 0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039]) @@ -620,7 +620,7 @@ def create_alibi_bias( Examples: - >>> Attention.create_alibi_bias(2, 4, torch.device('cpu')) + >>> AbstractAttention.create_alibi_bias(2, 4, torch.device('cpu')) tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000], [-0.0625, 0.0000, 0.0000, 0.0000], [-0.1250, -0.0625, 0.0000, 0.0000], From e2e0578eae03a11bb68f3806ba99216fc2383b0c Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Sat, 27 Apr 2024 21:00:09 +0200 Subject: [PATCH 71/73] fixed mypi error --- transformer_lens/components/abstract_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 2de38518f..c6fc0e5e8 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -58,8 +58,8 @@ def __init__( self.W_O = nn.Parameter( torch.empty(self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model, dtype=cfg.dtype) ) - self.W_K: Params4bit | nn.Parameter = abstract_attribute() - self.W_V: Params4bit | nn.Parameter = abstract_attribute() + self.W_K = abstract_attribute() + self.W_V = 
abstract_attribute() self.b_Q = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=cfg.dtype)) self.b_K: nn.Parameter = abstract_attribute() From ca6b8db53a07e7b729112795861fcd113bfe6cac Mon Sep 17 00:00:00 2001 From: Connor Kissane <67170576+ckkissane@users.noreply.github.com> Date: Tue, 30 Apr 2024 08:20:23 -0400 Subject: [PATCH 72/73] HookedSAETransformer (#536) * implement HookedSAETransformer * clean up imports * apply format * only recompute error if use_error_term * add tests * run format * fix import * match to hooks API * improve doc strings * improve demo * address Arthur feedback * try to fix indent: * try to fix indent again * change doc code block --- README.md | 3 +- demos/HookedSAETransformerDemo.ipynb | 18616 ++++++++++++++++++++ tests/unit/test_hooked_sae.py | 191 + tests/unit/test_hooked_sae_transformer.py | 515 + transformer_lens/HookedSAE.py | 118 + transformer_lens/HookedSAEConfig.py | 64 + transformer_lens/HookedSAETransformer.py | 290 + transformer_lens/__init__.py | 3 + 8 files changed, 19799 insertions(+), 1 deletion(-) create mode 100644 demos/HookedSAETransformerDemo.ipynb create mode 100644 tests/unit/test_hooked_sae.py create mode 100644 tests/unit/test_hooked_sae_transformer.py create mode 100644 transformer_lens/HookedSAE.py create mode 100644 transformer_lens/HookedSAEConfig.py create mode 100644 transformer_lens/HookedSAETransformer.py diff --git a/README.md b/README.md index 3f65c881d..ab7d11396 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ TransformerLens lets you load in 50+ different open source language models, and activations of the model to you. You can cache any internal activation in the model, and add in functions to edit, remove or replace these activations as the model runs. -~~ [OCTOBER SURVEY HERE](https://forms.gle/bw7U3PfioacDtFmT8) ~~ +The library also now supports mechanistic interpretability with SAEs (sparse autoencoders)! 
With [HookedSAETransformer](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/hooked-sae-transformer/demos/HookedSAETransformerDemo.ipynb), you can splice in SAEs during inference and cache + intervene on SAE activations. We recommend [SAELens](https://github.com/jbloomAus/SAELens) (built on top of TransformerLens) for training SAEs. ## Quick Start @@ -51,6 +51,7 @@ logits, activations = model.run_with_cache("Hello World") * [Introduction to the Library and Mech Interp](https://arena-ch1-transformers.streamlit.app/[1.2]_Intro_to_Mech_Interp) * [Demo of Main TransformerLens Features](https://neelnanda.io/transformer-lens-demo) +* [Demo of HookedSAETransformer Features](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/hooked-sae-transformer/demos/HookedSAETransformerDemo.ipynb) ## Gallery diff --git a/demos/HookedSAETransformerDemo.ipynb b/demos/HookedSAETransformerDemo.ipynb new file mode 100644 index 000000000..77d0d7c37 --- /dev/null +++ b/demos/HookedSAETransformerDemo.ipynb @@ -0,0 +1,18616 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# HookedSAETransformer Demo" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "HookedSAETransformer is a lightweight extension of HookedTransformer that allows you to \"splice in\" Sparse Autoencoders. This makes it easy to do exploratory analysis such as: running inference with SAEs attached, caching SAE feature activations, and intervening on SAE activations with hooks.\n", + "\n", + "I (Connor Kissane) implemented this to accelerate research on [Attention SAEs](https://www.lesswrong.com/posts/DtdzGwFh9dCfsekZZ/sparse-autoencoders-work-on-attention-layer-outputs) based on suggestions from Arthur Conmy and Neel Nanda, and found that it was well worth the time and effort. 
I hope other researchers will also find the library useful! This notebook demonstrates how it works and how to use it." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Jupyter notebook - intended for development only!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_10435/2185356984.py:16: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + " ipython.magic(\"load_ext autoreload\")\n", + "/tmp/ipykernel_10435/2185356984.py:17: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + " ipython.magic(\"autoreload 2\")\n" + ] + } + ], + "source": [ + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "DEVELOPMENT_MODE = False\n", + "try:\n", + " import google.colab\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + " %pip install git+https://github.com/ckkissane/TransformerLens@hooked-sae-transformer\n", + " \n", + "except:\n", + " IN_COLAB = False\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + " from IPython import get_ipython\n", + "\n", + " ipython = get_ipython()\n", + " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", + " ipython.magic(\"load_ext autoreload\")\n", + " ipython.magic(\"autoreload 2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import transformer_lens.utils as utils\n", + "\n", + "import plotly.express as px\n", + "import tqdm\n", + "from functools import partial\n", + "import einops\n", + 
"import plotly.graph_objects as go\n", + "\n", + "update_layout_set = {\n", + " \"xaxis_range\", \"yaxis_range\", \"hovermode\", \"xaxis_title\", \"yaxis_title\", \"colorbar\", \"colorscale\", \"coloraxis\",\n", + " \"title_x\", \"bargap\", \"bargroupgap\", \"xaxis_tickformat\", \"yaxis_tickformat\", \"title_y\", \"legend_title_text\", \"xaxis_showgrid\",\n", + " \"xaxis_gridwidth\", \"xaxis_gridcolor\", \"yaxis_showgrid\", \"yaxis_gridwidth\"\n", + "}\n", + "\n", + "def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n", + " if isinstance(tensor, list):\n", + " tensor = torch.stack(tensor)\n", + " kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}\n", + " kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}\n", + " if \"facet_labels\" in kwargs_pre:\n", + " facet_labels = kwargs_pre.pop(\"facet_labels\")\n", + " else:\n", + " facet_labels = None\n", + " if \"color_continuous_scale\" not in kwargs_pre:\n", + " kwargs_pre[\"color_continuous_scale\"] = \"RdBu\"\n", + " fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={\"x\":xaxis, \"y\":yaxis}, **kwargs_pre).update_layout(**kwargs_post)\n", + " if facet_labels:\n", + " for i, label in enumerate(facet_labels):\n", + " fig.layout.annotations[i]['text'] = label\n", + "\n", + " fig.show(renderer)\n", + "\n", + "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", renderer=None, return_fig=False, **kwargs):\n", + " x = utils.to_numpy(x)\n", + " y = utils.to_numpy(y)\n", + " fig = px.scatter(y=y, x=x, labels={\"x\":xaxis, \"y\":yaxis, \"color\":caxis}, **kwargs)\n", + " if return_fig:\n", + " return fig\n", + " fig.show(renderer)\n", + "\n", + "from typing import List\n", + "def show_avg_logit_diffs(x_axis: List[str], per_prompt_logit_diffs: List[torch.tensor]):\n", + "\n", + "\n", + " y_data = [per_prompt_logit_diff.mean().item() for per_prompt_logit_diff in per_prompt_logit_diffs]\n", + " error_y_data = 
[per_prompt_logit_diff.std().item() for per_prompt_logit_diff in per_prompt_logit_diffs] \n", + "\n", + " fig = go.Figure(data=[go.Bar(\n", + " x=x_axis,\n", + " y=y_data,\n", + " error_y=dict(\n", + " type='data', # specifies that the actual values are given\n", + " array=error_y_data, # the magnitudes of the errors\n", + " visible=True # make error bars visible\n", + " ),\n", + " )])\n", + "\n", + " # Customize layout\n", + " fig.update_layout(title_text=f'Logit Diff after Interventions',\n", + " xaxis_title_text='Intervention',\n", + " yaxis_title_text='Logit diff',\n", + " plot_bgcolor='white')\n", + "\n", + " # Show the figure\n", + " fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "torch.set_grad_enabled(False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Loading and Running Models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Just like a [HookedTransformer](https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Loading-and-Running-Models), we can load in any model that's supported in TransformerLens with the `HookedSAETransformer.from_pretrained(MODEL_NAME)`. In this demo we'll use GPT-2 small." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using pad_token, but it is not set yet.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model gpt2-small into HookedTransformer\n", + "Moving model to device: cuda\n" + ] + } + ], + "source": [ + "from transformer_lens import HookedSAETransformer\n", + "model: HookedSAETransformer = HookedSAETransformer.from_pretrained(\"gpt2-small\").to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By default HookedSAETransformer will behave exactly like a HookedTransformer. We'll explore the main features of HookedSAETransformer on the classic IOI task, so let's first sanity check that GPT2-small can do the IOI task without any SAEs attached:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['When John and Mary went to the shops, Mary gave the bag to', 'When John and Mary went to the shops, John gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']\n", + "[(' John', ' Mary'), (' Mary', ' John'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n" + ] + } + ], + "source": [ + "prompt_format = [\n", + " \"When John and Mary went to the shops,{} gave the bag to\",\n", + " \"When Tom and James went to the park,{} gave the ball to\",\n", + " \"When Dan and Sid went to the shops,{} gave an apple to\",\n", + " \"After Martin and Amy went to 
the park,{} gave a drink to\",\n", + "]\n", + "names = [\n", + " (\" John\", \" Mary\",),\n", + " (\" Tom\", \" James\"),\n", + " (\" Dan\", \" Sid\"),\n", + " (\" Martin\", \" Amy\"),\n", + "]\n", + "# List of prompts\n", + "prompts = []\n", + "# List of answers, in the format (correct, incorrect)\n", + "answers = []\n", + "# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)\n", + "answer_tokens = []\n", + "for i in range(len(prompt_format)):\n", + " for j in range(2):\n", + " answers.append((names[i][j], names[i][1 - j]))\n", + " answer_tokens.append(\n", + " (\n", + " model.to_single_token(answers[-1][0]),\n", + " model.to_single_token(answers[-1][1]),\n", + " )\n", + " )\n", + " # Insert the *incorrect* answer to the prompt, making the correct answer the indirect object.\n", + " prompts.append(prompt_format[i].format(answers[-1][1]))\n", + "answer_tokens = torch.tensor(answer_tokens).to(device)\n", + "print(prompts)\n", + "print(answers)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original average logit diff: 3.5518884658813477\n", + "Original per prompt logit diff: tensor([3.2016, 3.3367, 2.7095, 3.7975, 1.7204, 5.2812, 2.6008, 5.7674],\n", + " device='cuda:0')\n" + ] + } + ], + "source": [ + "def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):\n", + " # Only the final logits are relevant for the answer\n", + " final_logits = logits[:, -1, :]\n", + " answer_logits = final_logits.gather(dim=-1, index=answer_tokens)\n", + " answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]\n", + " if per_prompt:\n", + " return answer_logit_diff\n", + " else:\n", + " return answer_logit_diff.mean()\n", + " \n", + "tokens = model.to_tokens(prompts, prepend_bos=True)\n", + "original_logits, cache = model.run_with_cache(tokens)\n", + "original_average_logit_diff = 
logits_to_ave_logit_diff(original_logits, answer_tokens)\n", + "print(f\"Original average logit diff: {original_average_logit_diff}\")\n", + "original_per_prompt_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)\n", + "print(f\"Original per prompt logit diff: {original_per_prompt_logit_diff}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# HookedSAEs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to use the key features of HookedSAETransformer, we first need to load in SAEs.\n", + "\n", + "HookedSAE is an SAE class we've implemented to have TransformerLens hooks around the SAE activations. While we will use it out of the box, it is designed to be hackable: you can copy and paste the HookedSAE class into a notebook and completely change the architecture / hook names, and as long as it reconstructs the activations, it should still work.\n", + "\n", + "You can initialize a HookedSAE with a HookedSAEConfig:\n", + "```\n", + "cfg = HookedSAEConfig(\n", + " d_sae (int): The size of the dictionary.\n", + " d_in (int): The dimension of the input activations for the SAE\n", + " hook_name (str): The hook name of the activation the SAE was trained on (eg. blocks.0.attn.hook_z)\n", + ")\n", + "hooked_sae = HookedSAE(cfg)\n", + "```\n", + "\n", + "Note you'll likely have to write some basic conversion code to match configs / state dicts to the HookedSAE when loading in an open sourced SAE (eg from HuggingFace). We'll use our GPT-2 Small [Attention SAEs](https://www.alignmentforum.org/posts/FSTRedtjuHa4Gfdbr/attention-saes-scale-to-gpt-2-small) to demonstrate. For convenience, we'll load in all of our attention SAEs from HuggingFace, convert them to HookedSAEs, and store them in a dictionary that maps each hook_name (str) to the corresponding HookedSAE.\n", + "\n", + "
\n", + "\n", + "Later we'll show how to add HookedSAEs to the HookedSAETransformer (replacing model activations with their SAE reconstructions). When you add a HookedSAE, HookedSAETransformer just treats this a black box that takes some activation as an input, and outputs a tensor of the same shape. \n", + "\n", + "With this in mind, the HookedSAE is designed to be simple and hackable. Think of it as a convenient default class that you can copy and edit. As long as it takes a TransformerLens activation as input, and outputs a tensor of the same shape, you should be able to add it to your HookedSAETransformer.\n", + "\n", + "You probably don't even need to use the HookedSAE class, although it's recommended. The sae can be any pytorch module that takes in some activation at hook_name and outputs a tensor of the same shape. The two assumptions that HookedSAETransformer makes when adding SAEs are:\n", + "1. The SAE class has a cfg attribute, sae.cfg.hook_name (str), for the activation that the SAE was trained to reconstruct (in TransformerLens notation e.g. 'blocks.0.attn.hook_z')\n", + "2. The SAE takes that activation as input, and outputs a tensor of the same shape.\n", + "\n", + "The main benefit of HookedSAE is that it's a subclass of HookedRootModule, so we can add hooks to SAE activations. This makes it easy to leverage existing TL functionality like run_with_cache and run_with_hooks with SAEs.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['blocks.0.attn.hook_z', 'blocks.1.attn.hook_z', 'blocks.2.attn.hook_z', 'blocks.3.attn.hook_z', 'blocks.4.attn.hook_z', 'blocks.5.attn.hook_z', 'blocks.6.attn.hook_z', 'blocks.7.attn.hook_z', 'blocks.8.attn.hook_z', 'blocks.9.attn.hook_z', 'blocks.10.attn.hook_z', 'blocks.11.attn.hook_z'])\n" + ] + } + ], + "source": [ + "from transformer_lens import HookedSAE, HookedSAEConfig\n", + "from transformer_lens.utils import download_file_from_hf\n", + "def attn_sae_cfg_to_hooked_sae_cfg(attn_sae_cfg):\n", + " new_cfg = {\n", + " \"d_sae\": attn_sae_cfg[\"dict_size\"],\n", + " \"d_in\": attn_sae_cfg[\"act_size\"],\n", + " \"hook_name\": attn_sae_cfg[\"act_name\"],\n", + " }\n", + " return HookedSAEConfig.from_dict(new_cfg)\n", + "\n", + "auto_encoder_runs = [\n", + " \"gpt2-small_L0_Hcat_z_lr1.20e-03_l11.80e+00_ds24576_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v9\",\n", + " \"gpt2-small_L1_Hcat_z_lr1.20e-03_l18.00e-01_ds24576_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v5\",\n", + " \"gpt2-small_L2_Hcat_z_lr1.20e-03_l11.00e+00_ds24576_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v4\",\n", + " \"gpt2-small_L3_Hcat_z_lr1.20e-03_l19.00e-01_ds24576_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v9\",\n", + " \"gpt2-small_L4_Hcat_z_lr1.20e-03_l11.10e+00_ds24576_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v7\",\n", + " \"gpt2-small_L5_Hcat_z_lr1.20e-03_l11.00e+00_ds49152_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v9\",\n", + " \"gpt2-small_L6_Hcat_z_lr1.20e-03_l11.10e+00_ds24576_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v9\",\n", + " \"gpt2-small_L7_Hcat_z_lr1.20e-03_l11.10e+00_ds49152_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v9\",\n", + " \"gpt2-small_L8_Hcat_z_lr1.20e-03_l11.30e+00_ds24576_bs4096_dc1.00e-05_rsanthropic_rie25000_nr4_v6\",\n", + " 
\"gpt2-small_L9_Hcat_z_lr1.20e-03_l11.20e+00_ds24576_bs4096_dc1.00e-06_rsanthropic_rie25000_nr4_v9\",\n", + " \"gpt2-small_L10_Hcat_z_lr1.20e-03_l11.30e+00_ds24576_bs4096_dc1.00e-05_rsanthropic_rie25000_nr4_v9\",\n", + " \"gpt2-small_L11_Hcat_z_lr1.20e-03_l13.00e+00_ds24576_bs4096_dc3.16e-06_rsanthropic_rie25000_nr4_v9\",\n", + "]\n", + "\n", + "hf_repo = \"ckkissane/attn-saes-gpt2-small-all-layers\"\n", + "\n", + "hook_name_to_sae = {}\n", + "for auto_encoder_run in auto_encoder_runs:\n", + " attn_sae_cfg = download_file_from_hf(hf_repo, f\"{auto_encoder_run}_cfg.json\")\n", + " cfg = attn_sae_cfg_to_hooked_sae_cfg(attn_sae_cfg)\n", + " \n", + " state_dict = download_file_from_hf(hf_repo, f\"{auto_encoder_run}.pt\", force_is_torch=True)\n", + " \n", + " hooked_sae = HookedSAE(cfg)\n", + " hooked_sae.load_state_dict(state_dict)\n", + " \n", + " hook_name_to_sae[cfg.hook_name] = hooked_sae\n", + "print(hook_name_to_sae.keys())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run with SAEs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The key feature of HookedSAETransformer is being able to \"splice in\" SAEs, replacing model activations with their SAE reconstructions. \n", + "\n", + "To run a forward pass with SAEs attached use `model.run_with_saes(tokens, saes=saes)`, where saes is a list of HookedSAEs that you want to add for just this forward pass. 
These will be reset immediately after the forward pass, returning the model to its original state.\n", + "\n", + "I expect this to be particularly useful for evaluating SAEs (eg [Gurnee](https://www.alignmentforum.org/posts/rZPiuFxESMxCDHe4B/sae-reconstruction-errors-are-empirically-pathological)), including evaluating how SAE reconstructions affect the models ability to perform certain tasks (eg [Makelov et al.](https://openreview.net/forum?id=MHIX9H8aYF&referrer=%5Bthe%20profile%20of%20Neel%20Nanda%5D(%2Fprofile%3Fid%3D~Neel_Nanda1)))\n", + "\n", + "To demonstrate, let's use `run_with_saes` to evaluate many combinations of SAEs on different cross sections of the IOI circuit.\n", + "\n", + "
\n", + "\n", + "Under the hood, TransformerLens already wraps activations with a HookPoint object. HookPoint is a dummy pytorch module that acts as an identity function by default, and is only used to access the activation with PyTorch hooks. When you run_with_saes, HookedSAETransformer temporarily replaces these HookPoints with the given HookedSAEs, which take the activation as input and replace it with the HookedSAE output (the reconstructed activation) during the forward pass. \n", + "\n", + "Since HookedSAE is a subclass of HookedRootModule, we also are able to add PyTorch hooks to the corresponding SAE activations, as we'll use later.\n", + "\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "error_y": { + "array": [ + 1.3678313493728638, + 1.6846193075180054, + 1.3839112520217896, + 1.6782633066177368, + 0.8939867615699768, + 2.2888872623443604 + ], + "type": "data", + "visible": true + }, + "type": "bar", + "x": [ + "Clean Baseline", + "With SAEs L[0, 3]", + "With SAEs L[2, 4]", + "With SAEs L[5, 6]", + "With SAEs L[7, 8]", + "With SAEs L[9, 10, 11]" + ], + "y": [ + 3.5518884658813477, + 2.580843925476074, + 3.3641157150268555, + 3.3500614166259766, + 1.5024915933609009, + 7.072007179260254 + ] + } + ], + "layout": { + "plot_bgcolor": "white", + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 
0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + 
], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + 
"ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 
0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", 
+ "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Logit Diff after Interventions" + }, + "xaxis": { + "title": { + "text": "Intervention" + } + }, + "yaxis": { + "title": { + "text": "Logit diff" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "all_layers = [[0, 3], [2, 4], [5,6], [7, 8], [9, 10, 11]]\n", + "x_axis = ['Clean Baseline']\n", + "per_prompt_logit_diffs = [\n", + " original_per_prompt_logit_diff, \n", + "]\n", + "\n", + "for layers in all_layers:\n", + " hooked_saes = [hook_name_to_sae[utils.get_act_name('z', layer)] for layer in layers]\n", + " logits_with_saes = model.run_with_saes(tokens, saes=hooked_saes)\n", + " average_logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)\n", + " per_prompt_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)\n", + " \n", + " x_axis.append(f\"With SAEs L{layers}\")\n", + " per_prompt_logit_diffs.append(per_prompt_diff_with_saes)\n", + "\n", + "show_avg_logit_diffs(x_axis, per_prompt_logit_diffs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run with cache (with SAEs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We often want to see what SAE features are active on a given prompt. With HookedSAETransformer, you can cache HookedSAE activations (and all the other standard activations) with `logits, cache = model.run_with_cache_with_saes(tokens, saes=saes)`. Just as `run_with_saes` is a wapper around the standard forward pass, `run_with_cache_with_saes` is a wrapper around `run_with_cache`, and will also only add these saes for one forward pass before returning the model to its original state. \n", + "\n", + "To access SAE activations from the cache, the corresponding hook names will generally be the HookedTransformer hook_name (eg blocks.5.attn.hook_z) + the hookedSAE hooked name preceeded by a period (eg .hook_sae_acts_post).\n", + "\n", + "`run_with_cache_with_saes` makes it easy to explore which SAE features are active across any input. 
Let's explore the active features at the S2 position for our L5 Attention SAE across all of our IOI prompts:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Feature Id: %{x}
Prompt: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "x": [ + "46", + "345", + "702", + "1372", + "1755", + "1965", + "2457", + "2496", + "2646", + "2999", + "3047", + "4569", + "5132", + "5203", + "5508", + "5940", + "6144", + "6371", + "6515", + "6558", + "6812", + "7092", + "7515", + "7907", + "8063", + "8623", + "8737", + "8768", + "9096", + "9102", + "9186", + "9463", + "9746", + "9913", + "10581", + "10894", + "12109", + "12485", + "12764", + "12866", + "13063", + "13624", + "13707", + "13777", + "14844", + "15050", + "15170", + "15696", + "16178", + "16892", + "17156", + "17259", + "17497", + "17854", + "18043", + "18210", + "18318", + "18385", + "18440", + "18920", + "19183", + "19263", + "19442", + "19524", + "19573", + "20838", + "21151", + "21657", + "22108", + "23578", + "24091", + "24217", + "25792", + "26373", + "26410", + "27535", + "27787", + "27811", + "27960", + "28061", + "28241", + "28242", + "28254", + "28349", + "28977", + "29027", + "29482", + "29603", + "29700", + "29822", + "32177", + "32920", + "33320", + "33730", + "33966", + "34177", + "34334", + "34947", + "35403", + "35425", + "35579", + "35665", + "35815", + "36109", + "36172", + "36451", + "36767", + "36917", + "38570", + "39962", + "40409", + "40418", + "40661", + "41162", + "41185", + "41552", + "42024", + "42161", + "42437", + "42577", + "42882", + "42931", + "43035", + "43414", + "43643", + "43662", + "44203", + "44256", + "44452", + "44652", + "45179", + "45814", + "45984", + "46880", + "47117", + "47170", + "47231", + "47313", + "47680", + "48063", + "48703" + ], + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.23392018675804138, + 0, + 0, + 0.04335343837738037, + 0.44275617599487305, + 0, + 0, + 0.07259953022003174, + 0, + 0.6985604763031006, + 1.262436866760254, + 0, + 0.04656928777694702, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.45666736364364624, + 0.10434150695800781, + 0.30980953574180603, + 0.3319076895713806, + 0, + 0, + 0, + 0, + 
1.7836596965789795, + 0, + 0, + 0.142583429813385, + 0.046830952167510986, + 0.3180348575115204, + 0.2927079200744629, + 0.12267106771469116, + 2.5688514709472656, + 0.2917236089706421, + 0.12333670258522034, + 0, + 0.1778419017791748, + 0, + 0.023626387119293213, + 0.02943563461303711, + 0, + 0.048882365226745605, + 0.13625454902648926, + 0, + 0, + 0.2634885013103485, + 0, + 0, + 0, + 0.21662655472755432, + 0, + 0, + 0, + 0.06997489929199219, + 0.006345987319946289, + 0, + 0.16112494468688965, + 0.4190089702606201, + 0, + 2.3819468021392822, + 1.0431660413742065, + 0, + 0.08364987373352051, + 0, + 0, + 0.3451769948005676, + 0.7391350865364075, + 0.4456520080566406, + 0.0019606351852416992, + 0.39914217591285706, + 0, + 0, + 0, + 0.29958274960517883, + 0.44243645668029785, + 0, + 0.1259920299053192, + 0.8349504470825195, + 0.37993764877319336, + 0.2633737325668335, + 0.08324140310287476, + 0, + 0, + 0.10421907901763916, + 0, + 0, + 0, + 0.36972635984420776, + 0, + 0, + 0, + 0, + 0.5578295588493347, + 0, + 0.9233021140098572, + 0, + 0.10010790824890137, + 0, + 0.45082613825798035, + 0, + 0, + 0, + 0.21043556928634644, + 0.12981292605400085, + 0.11557984352111816, + 0, + 0, + 0.17571094632148743, + 0.2823787331581116, + 0.1122598648071289, + 0, + 0, + 0.012049257755279541, + 0, + 0, + 0, + 2.417463541030884, + 0.0547795295715332, + 0.05216425657272339, + 0, + 0.6592545509338379, + 0.003663182258605957, + 0, + 0, + 0.04937589168548584, + 0.025814831256866455, + 0, + 0.8019273281097412, + 0, + 0.10218703746795654 + ], + [ + 0, + 0, + 0.3230956792831421, + 0, + 0, + 0, + 0.026041746139526367, + 0.31818556785583496, + 0, + 0.4900796413421631, + 0.04911249876022339, + 0, + 0, + 0.07309412956237793, + 0.08089971542358398, + 0.17180073261260986, + 0, + 0, + 0, + 0, + 0, + 0, + 2.3956947326660156, + 0, + 0, + 0.15781426429748535, + 0, + 0.5073252320289612, + 0.21765804290771484, + 0, + 0, + 1.618570327758789, + 0, + 0.22485831379890442, + 0.0830467939376831, + 
0.7055595517158508, + 0, + 0, + 0, + 0, + 0.23371747136116028, + 0, + 0, + 0.6983060240745544, + 0, + 0, + 0, + 0, + 0.30831730365753174, + 0, + 0.417669415473938, + 0.05292201042175293, + 0, + 0, + 0, + 1.3391070365905762, + 0, + 0.41352108120918274, + 0, + 0, + 0, + 0.037178993225097656, + 0, + 0, + 0, + 0, + 0.2702980041503906, + 0, + 0, + 0.18745100498199463, + 1.3330132961273193, + 0.5793700814247131, + 0.33893001079559326, + 0, + 0.11196631193161011, + 1.720167636871338, + 0.17581266164779663, + 0.42567259073257446, + 0, + 0, + 0.23682871460914612, + 0, + 0, + 0, + 0, + 0, + 1.8280882835388184, + 0.1617840826511383, + 0, + 0.13557660579681396, + 0.5832244157791138, + 0, + 0, + 0.03256487846374512, + 0, + 0, + 0.03892314434051514, + 0, + 0, + 0, + 0.30978846549987793, + 0, + 0, + 0.36915141344070435, + 0, + 0.5477294325828552, + 0, + 0, + 0.06339260935783386, + 0.1851767599582672, + 0.5839155912399292, + 0, + 0, + 0, + 0, + 0, + 0.12337607145309448, + 0, + 0, + 1.0378936529159546, + 0, + 0, + 0, + 0.01616498827934265, + 0.20259439945220947, + 0, + 0, + 0.3087460398674011, + 0.618510365486145, + 0.24435847997665405, + 0, + 0.4668591022491455, + 0.1788468360900879, + 0.200361967086792, + 0, + 0, + 0, + 0, + 0, + 0, + 0.7064645290374756 + ], + [ + 0.2921750843524933, + 0, + 0, + 0.2805737257003784, + 0, + 0, + 0, + 0.3694216012954712, + 0, + 1.1156601905822754, + 1.2807728052139282, + 0, + 0.09175515174865723, + 0, + 0, + 0, + 0.10458803176879883, + 0, + 0.021218180656433105, + 0, + 0, + 0.01699376106262207, + 0.09601330757141113, + 0.054788172245025635, + 0, + 0.030488133430480957, + 0.021512210369110107, + 0.2717320919036865, + 0.29357004165649414, + 0.6420693397521973, + 0.05249035358428955, + 0, + 0.06201601028442383, + 0, + 0.4122554659843445, + 1.821354866027832, + 0.01981794834136963, + 0, + 0.14063221216201782, + 0.05093127489089966, + 0, + 0.32148706912994385, + 0.15257668495178223, + 2.418062686920166, + 0.17348229885101318, + 0.08421656489372253, + 0, 
+ 0.4551248550415039, + 0, + 0.015430927276611328, + 0.24434363842010498, + 0, + 0.06232607364654541, + 0, + 0.04422914981842041, + 0.8720088005065918, + 0.3721686899662018, + 0, + 0, + 0, + 0.340120404958725, + 0, + 0, + 0.07813769578933716, + 0, + 0.0882720947265625, + 0.19706517457962036, + 0.4056885242462158, + 0.19529414176940918, + 0, + 2.928431510925293, + 1.1402223110198975, + 0, + 0.026796698570251465, + 0.0033188462257385254, + 0, + 0.3370524048805237, + 0.47657889127731323, + 0, + 0.10358679294586182, + 0.27619925141334534, + 0, + 0, + 0, + 0.40909066796302795, + 0.2599871754646301, + 0, + 0.275011271238327, + 0.5349323749542236, + 0.07697033882141113, + 0.17431437969207764, + 0, + 0, + 0, + 0.09000074863433838, + 0, + 0, + 0, + 0.276567280292511, + 0, + 0, + 0, + 0, + 0.5655339360237122, + 0, + 0.8971189856529236, + 0, + 0.5199201107025146, + 0, + 0.6301102638244629, + 0.013657361268997192, + 0.04469645023345947, + 0.038062095642089844, + 0.4305816888809204, + 0, + 0.04173767566680908, + 0, + 0, + 0, + 0.8985729217529297, + 0, + 0, + 0, + 0, + 0, + 0.08318889141082764, + 0.006362795829772949, + 2.069222927093506, + 0, + 0.7068352103233337, + 0, + 0.8527798652648926, + 0, + 0, + 0.4707651138305664, + 0, + 0, + 0, + 0.7790955305099487, + 0.021227538585662842, + 0.01846003532409668 + ], + [ + 0, + 0, + 0.2200499176979065, + 0, + 0, + 0, + 0, + 0.2433047890663147, + 0.2504638135433197, + 0.712148904800415, + 0, + 0, + 0, + 0, + 0, + 0.1410943865776062, + 0, + 0, + 0, + 0.11292147636413574, + 0, + 0, + 2.360842704772949, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1.2830760478973389, + 0, + 0, + 0, + 0.6308119893074036, + 0, + 0.4040885865688324, + 0, + 0, + 0, + 0, + 0, + 0.5223236680030823, + 0, + 0, + 0, + 0, + 0.23784160614013672, + 0, + 0.04762387275695801, + 0, + 0, + 0, + 0, + 0.5758676528930664, + 0.01025208830833435, + 0.24556085467338562, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1.1104614734649658, + 1.079118251800537, + 0, + 0, 
+ 0.14462929964065552, + 1.9186956882476807, + 0, + 0.30735498666763306, + 0, + 0, + 0.07669633626937866, + 0, + 0, + 0, + 0, + 0, + 1.3975048065185547, + 0, + 0, + 0.3461639881134033, + 0.5062156915664673, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.19610454142093658, + 0.218009352684021, + 0, + 0, + 0.07953745126724243, + 0, + 0.1416093111038208, + 0, + 0, + 0, + 0.18305465579032898, + 0.10310900211334229, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.45315277576446533, + 0, + 0, + 0, + 0.09076884388923645, + 0, + 0, + 0, + 0, + 0, + 0.04246491193771362, + 0, + 0.1807355284690857, + 0, + 0.3002055883407593, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0.02005404233932495, + 0, + 0, + 0.07601284980773926, + 0, + 0, + 0, + 0.012166053056716919, + 0, + 1.0662918090820312, + 1.4810535907745361, + 0, + 0.014786958694458008, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.1491186022758484, + 0, + 0, + 0, + 0.38226866722106934, + 0.43110355734825134, + 0, + 0, + 0, + 0, + 0, + 1.6819074153900146, + 0, + 0.7939910888671875, + 0.28643298149108887, + 0, + 0, + 0.011532962322235107, + 0, + 1.2869157791137695, + 0, + 0, + 0, + 0, + 0, + 0.16446048021316528, + 0, + 0, + 0, + 0, + 0, + 0, + 0.03375712037086487, + 0, + 0, + 0, + 0.1915181577205658, + 0, + 0, + 0.10225892066955566, + 0, + 0, + 0, + 0.7338485717773438, + 0, + 0, + 1.3715617656707764, + 1.6115869283676147, + 0, + 0.7128411531448364, + 0, + 0, + 0.2161598801612854, + 0.5098914504051208, + 0, + 0, + 0.04084053635597229, + 0, + 0, + 0, + 0.17978456616401672, + 0, + 0, + 0.1365671455860138, + 0.27122950553894043, + 0.2945059537887573, + 0.2824629545211792, + 0, + 0, + 0, + 0.0464092493057251, + 0, + 0, + 0.04672741889953613, + 0.6179839968681335, + 0, + 0, + 0, + 0, + 0.045598745346069336, + 0, + 1.0172381401062012, + 0, + 0.07242608070373535, + 0, + 0.5165215730667114, + 0, + 0, + 0, + 0.5004003047943115, + 0, + 0, + 0, + 0, + 0, + 0.3409433960914612, + 0, + 0.1579979658126831, + 0.09901612997055054, + 0, + 
0, + 0, + 0, + 2.413944721221924, + 0, + 0.20971286296844482, + 0.07062971591949463, + 0.26070594787597656, + 0, + 0, + 0, + 0, + 0, + 0.020640969276428223, + 1.0534553527832031, + 0, + 0 + ], + [ + 0, + 0, + 0.046907246112823486, + 0, + 0, + 0, + 0, + 0.20885008573532104, + 0.25957152247428894, + 1.0767037868499756, + 0, + 0, + 0, + 0, + 0, + 0.23976856470108032, + 0, + 0, + 0, + 0, + 0, + 0, + 2.762990951538086, + 0, + 0, + 0, + 0, + 0.29466086626052856, + 0, + 0, + 0.09433537721633911, + 1.2446393966674805, + 0, + 0, + 0, + 0.6668079495429993, + 0, + 0.7482341527938843, + 0, + 0, + 0.005075186491012573, + 0, + 0, + 0.4049275517463684, + 0, + 0, + 0, + 0, + 0.09314888715744019, + 0, + 0, + 0, + 0, + 0, + 0, + 0.4028928279876709, + 0, + 0.3687801659107208, + 0, + 0.10555410385131836, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1.066054105758667, + 1.4596349000930786, + 0, + 0, + 0, + 2.3358588218688965, + 0, + 0.5390753149986267, + 0, + 0, + 0.12931063771247864, + 0, + 0.10619288682937622, + 0, + 0, + 0, + 0.41271400451660156, + 0, + 0, + 0.23865878582000732, + 0.7501264810562134, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.2947666645050049, + 0, + 0, + 0, + 0.05958199501037598, + 0.20450782775878906, + 0, + 0, + 0, + 0.13838836550712585, + 0.13835513591766357, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.45820748805999756, + 0, + 0, + 0, + 0.19962045550346375, + 0, + 0, + 0, + 0, + 0.20416772365570068, + 0.46223968267440796, + 0, + 0.22815394401550293, + 0, + 0.1125795841217041, + 0, + 0, + 0, + 0, + 0, + 0, + 0.3023688793182373 + ], + [ + 0.28365251421928406, + 0, + 0, + 0.41595208644866943, + 0, + 0.15376341342926025, + 0, + 0.22517156600952148, + 0, + 0.7871096134185791, + 1.3084614276885986, + 0.2012956142425537, + 0, + 0, + 0, + 0.2532406449317932, + 0.009012699127197266, + 0, + 0, + 0, + 0, + 0.7235959768295288, + 0.021468758583068848, + 0, + 0, + 0, + 0, + 0.8338297009468079, + 0.3022422790527344, + 0.6702529191970825, + 0.5416026711463928, + 
0, + 0, + 0, + 0.2034381628036499, + 1.9052581787109375, + 0, + 0.23752644658088684, + 0, + 0, + 0, + 0.8470145463943481, + 0, + 2.820002555847168, + 0, + 0.16275432705879211, + 0.06714236736297607, + 0.12017238140106201, + 0, + 0, + 0, + 0, + 0.486280620098114, + 0, + 0, + 0.3096342086791992, + 0.3064201772212982, + 0, + 0.09773910045623779, + 0, + 0.4613642394542694, + 0, + 0.021892428398132324, + 0, + 0.18887782096862793, + 0.18538141250610352, + 0, + 0.42975664138793945, + 0.9873132705688477, + 0, + 2.163774013519287, + 1.2928048372268677, + 0, + 0.2320784330368042, + 0.0062233805656433105, + 0, + 1.2478563785552979, + 0.5479208827018738, + 0, + 0.06501156091690063, + 0.3741762936115265, + 0, + 0, + 0.31712013483047485, + 0.5228050947189331, + 0.3981531858444214, + 0, + 0, + 0.4854400157928467, + 0.3341655731201172, + 0.39207732677459717, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.3316766023635864, + 0, + 0, + 0.33435362577438354, + 0.1380615234375, + 0.7183249592781067, + 0.041296958923339844, + 0.7634149193763733, + 0, + 0.4028007984161377, + 0, + 0.6915435791015625, + 0, + 0, + 0, + 0.3831353187561035, + 0.05798754096031189, + 0.15244710445404053, + 0, + 0.03230410814285278, + 0.2039397656917572, + 0.6142292022705078, + 0.15542924404144287, + 0.07628917694091797, + 0.0812273919582367, + 0.15177401900291443, + 0.10224854946136475, + 0, + 0, + 2.8106069564819336, + 0.3994237184524536, + 0.6397127509117126, + 0, + 0.8949670791625977, + 0, + 0, + 0.18832790851593018, + 0.1450880765914917, + 0, + 0, + 0.6900937557220459, + 0, + 0.14745783805847168 + ], + [ + 0.12055802345275879, + 0.023864269256591797, + 0, + 0, + 0, + 0, + 0, + 0.3327372670173645, + 0.1789897382259369, + 1.1445300579071045, + 0, + 0, + 0, + 0, + 0, + 0.4361664652824402, + 0.09996795654296875, + 0.10051405429840088, + 0, + 0.4030296802520752, + 0.06672021746635437, + 0.6339577436447144, + 3.3947582244873047, + 0, + 0, + 0, + 0, + 0.9711236357688904, + 0, + 0.38066884875297546, + 0.4158353805541992, + 
1.5344438552856445, + 0, + 0.19816407561302185, + 0, + 0.6646860241889954, + 0, + 0.16733816266059875, + 0, + 0, + 0, + 0.322623074054718, + 0, + 0.7314171195030212, + 0, + 0, + 0, + 0, + 0.043955981731414795, + 0, + 0, + 0, + 0, + 0, + 0, + 0.9436180591583252, + 0, + 0.29259607195854187, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.1570979356765747, + 0, + 0, + 0, + 1.1782727241516113, + 1.2431498765945435, + 0.32878363132476807, + 0, + 0.419150173664093, + 2.3304405212402344, + 0.8566346764564514, + 0, + 0, + 0, + 0.3841046392917633, + 0.10476112365722656, + 0, + 0.18140661716461182, + 0, + 0, + 0.6665420532226562, + 0, + 0, + 0.22877633571624756, + 0.9225524663925171, + 0, + 0.15886402130126953, + 0, + 0, + 0.02094721794128418, + 0, + 0, + 0, + 0.3046541213989258, + 0.2845715284347534, + 0, + 0, + 0.4244043231010437, + 0.164473295211792, + 0.30073386430740356, + 0.7123112678527832, + 0.1730642318725586, + 0, + 0.4041661322116852, + 0.39166414737701416, + 0, + 0, + 0.2103893756866455, + 0.007811635732650757, + 0.010994672775268555, + 0.03914850950241089, + 0, + 0, + 0.8430832624435425, + 0, + 0, + 0, + 0.15830591320991516, + 0.29398930072784424, + 0, + 0, + 0, + 0.5994948148727417, + 0.1704254150390625, + 0, + 0.4673898220062256, + 0, + 0.3204514980316162, + 0, + 0, + 0, + 0, + 0, + 0, + 0.8447363376617432 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, 
+ "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": 
"" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + 
"fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { 
+ "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + 
"bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Activations of Live SAE features at L5 S2 position per prompt" + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Feature Id" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Prompt" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "layer, s2_pos = 5, 10\n", + "saes = [hook_name_to_sae[utils.get_act_name('z', layer)]]\n", + "_, cache = model.run_with_cache_with_saes(tokens, saes=saes)\n", + "sae_acts = cache[utils.get_act_name('z', layer) + \".hook_sae_acts_post\"][:, s2_pos, :]\n", + "live_feature_mask = sae_acts > 0\n", + "live_feature_union = live_feature_mask.any(dim=0)\n", + "\n", + "imshow(\n", + " sae_acts[:, live_feature_union],\n", + " title = \"Activations of Live SAE features at L5 S2 position per prompt\",\n", + " xaxis=\"Feature Id\", yaxis=\"Prompt\",\n", + " x=list(map(str, live_feature_union.nonzero().flatten().tolist())),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We could then interpret some of the commonly activating features, like 7515, using [neuronpedia](https://www.neuronpedia.org/gpt2-small/5-att-kk/7515)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run with Hooks (with SAEs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "HookedSAETransformer also allows you to intervene on SAE activations with `model.run_with_hooks_with_saes(tokens, saes=saes, fwd_hooks=fwd_hooks)`. This works exactly like the standard TransformerLens `run_with_hooks`, with the added benefit that we can now intervene on SAE activations from the HookedSAEs that we splice in. Along the same lines as `run_with_saes` and `run_with_cache_with_saes`, this will only temporarily add SAEs before returning the model to it's original state. \n", + "\n", + "I expect this to be useful when doing circuit analysis with SAEs. To demonstrate, let's zero ablate individual layer 5 attention SAE features to localize causally important features." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 141/141 [00:04<00:00, 28.85it/s]\n" + ] + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Feature Idx: %{x}
Prompt Idx: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "x": [ + "46", + "345", + "702", + "1372", + "1755", + "1965", + "2457", + "2496", + "2646", + "2999", + "3047", + "4569", + "5132", + "5203", + "5508", + "5940", + "6144", + "6371", + "6515", + "6558", + "6812", + "7092", + "7515", + "7907", + "8063", + "8623", + "8737", + "8768", + "9096", + "9102", + "9186", + "9463", + "9746", + "9913", + "10581", + "10894", + "12109", + "12485", + "12764", + "12866", + "13063", + "13624", + "13707", + "13777", + "14844", + "15050", + "15170", + "15696", + "16178", + "16892", + "17156", + "17259", + "17497", + "17854", + "18043", + "18210", + "18318", + "18385", + "18440", + "18920", + "19183", + "19263", + "19442", + "19524", + "19573", + "20838", + "21151", + "21657", + "22108", + "23578", + "24091", + "24217", + "25792", + "26373", + "26410", + "27535", + "27787", + "27811", + "27960", + "28061", + "28241", + "28242", + "28254", + "28349", + "28977", + "29027", + "29482", + "29603", + "29700", + "29822", + "32177", + "32920", + "33320", + "33730", + "33966", + "34177", + "34334", + "34947", + "35403", + "35425", + "35579", + "35665", + "35815", + "36109", + "36172", + "36451", + "36767", + "36917", + "38570", + "39962", + "40409", + "40418", + "40661", + "41162", + "41185", + "41552", + "42024", + "42161", + "42437", + "42577", + "42882", + "42931", + "43035", + "43414", + "43643", + "43662", + "44203", + "44256", + "44452", + "44652", + "45179", + "45814", + "45984", + "46880", + "47117", + "47170", + "47231", + "47313", + "47680", + "48063", + "48703" + ], + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.006268501281738281, + 0, + 0, + 0.0016260147094726562, + 0.0011568069458007812, + 0, + 0, + -0.000400543212890625, + 0, + -0.024961471557617188, + -0.062079429626464844, + 0, + 0.00041866302490234375, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.017510414123535156, + -0.0021581649780273438, + -0.0012054443359375, + -0.006356239318847656, + 0, 
+ 0, + 0, + 0, + 0.025524139404296875, + 0, + 0, + -0.0037746429443359375, + 0.0004291534423828125, + -0.000194549560546875, + 0.002796173095703125, + 0.0001850128173828125, + -0.056549072265625, + -0.0029163360595703125, + -0.004790306091308594, + 0, + 0.0005321502685546875, + 0, + 0.00049591064453125, + -0.0008335113525390625, + 0, + -0.00299072265625, + -0.00185394287109375, + 0, + 0, + 0.011702537536621094, + 0, + 0, + 0, + -0.003353118896484375, + 0, + 0, + 0, + 0.00048828125, + -0.000213623046875, + 0, + -0.0062160491943359375, + -0.007611274719238281, + 0, + 0.06644821166992188, + -0.025884628295898438, + 0, + -0.0001964569091796875, + 0, + 0, + 0.03233909606933594, + -0.05103874206542969, + 0.0003414154052734375, + -0.0000057220458984375, + -0.0027713775634765625, + 0, + 0, + 0, + -0.02438068389892578, + 0.027306556701660156, + 0, + -0.0036411285400390625, + 0.018335342407226562, + 0.010270118713378906, + 0.0120849609375, + 0.0013589859008789062, + 0, + 0, + -0.0033817291259765625, + 0, + 0, + 0, + -0.014057159423828125, + 0, + 0, + 0, + 0, + -0.008485794067382812, + 0, + 0.021463394165039062, + 0, + -0.002582550048828125, + 0, + 0.012966156005859375, + 0, + 0, + 0, + -0.0077991485595703125, + 0.002948760986328125, + 0.0069675445556640625, + 0, + 0, + 0.0058879852294921875, + -0.050632476806640625, + 0.001888275146484375, + 0, + 0, + -0.0005016326904296875, + 0, + 0, + 0, + -0.5087032318115234, + -0.0006818771362304688, + 0.0017566680908203125, + 0, + -0.02089214324951172, + -0.0000286102294921875, + 0, + 0, + -0.000446319580078125, + 0.0008115768432617188, + 0, + 0.017795562744140625, + 0, + -0.008462905883789062 + ], + [ + 0, + 0, + 0.0042266845703125, + 0, + 0, + 0, + -0.00130462646484375, + -0.01946258544921875, + 0, + 0.03999900817871094, + 0.013164520263671875, + 0, + 0, + -0.000522613525390625, + -0.0028820037841796875, + -0.003643035888671875, + 0, + 0, + 0, + 0, + 0, + 0, + -0.24383163452148438, + 0, + 0, + -0.0009517669677734375, + 0, + 
0.05923271179199219, + 0.00897979736328125, + 0, + 0, + -0.00617218017578125, + 0, + 0.011938095092773438, + 0.005764007568359375, + 0.08927345275878906, + 0, + 0, + 0, + 0, + 0.027820587158203125, + 0, + 0, + 0.021488189697265625, + 0, + 0, + 0, + 0, + 0.016414642333984375, + 0, + -0.012666702270507812, + 0.002353668212890625, + 0, + 0, + 0, + 0.10541152954101562, + 0, + 0.010334014892578125, + 0, + 0, + 0, + 0.0012111663818359375, + 0, + 0, + 0, + 0, + -0.047576904296875, + 0, + 0, + -0.006137847900390625, + 0.04940223693847656, + 0.014007568359375, + 0.030317306518554688, + 0, + -0.0012969970703125, + -0.12521743774414062, + 0.0023975372314453125, + 0.04903602600097656, + 0, + 0, + 0.019681930541992188, + 0, + 0, + 0, + 0, + 0, + -0.07957077026367188, + -0.00966644287109375, + 0, + 0.011016845703125, + 0.05775642395019531, + 0, + 0, + 0.00060272216796875, + 0, + 0, + 0.00067138671875, + 0, + 0, + 0, + -0.0041980743408203125, + 0, + 0, + 0.020341873168945312, + 0, + -0.02782440185546875, + 0, + 0, + 0.001705169677734375, + 0.0035266876220703125, + 0.0060558319091796875, + 0, + 0, + 0, + 0, + 0, + 0.0004119873046875, + 0, + 0, + 0.10181617736816406, + 0, + 0, + 0, + 0.0001964569091796875, + 0.009687423706054688, + 0, + 0, + 0.10214805603027344, + 0.03883934020996094, + 0.028743743896484375, + 0, + -0.009389877319335938, + -0.0005168914794921875, + -0.0241851806640625, + 0, + 0, + 0, + 0, + 0, + 0, + 0.0089263916015625 + ], + [ + 0.013156890869140625, + 0, + 0, + 0.00737762451171875, + 0, + 0, + 0, + -0.011926651000976562, + 0, + -0.1016092300415039, + -0.2541160583496094, + 0, + 0.0026063919067382812, + 0, + 0, + 0, + 0.011356353759765625, + 0, + -0.0003261566162109375, + 0, + 0, + 0.000354766845703125, + 0.018985748291015625, + -0.0010251998901367188, + 0, + -0.0016918182373046875, + 0.00087738037109375, + -0.03418159484863281, + -0.022599220275878906, + -0.031129837036132812, + -0.0039033889770507812, + 0, + 0.002773284912109375, + 0, + -0.0497589111328125, + 
0.0000972747802734375, + 0.00002002716064453125, + 0, + -0.000766754150390625, + 0.000133514404296875, + 0, + 0.00109100341796875, + 0.00045013427734375, + -0.15281009674072266, + -0.0027723312377929688, + -0.008421897888183594, + 0, + 0.024028778076171875, + 0, + 0.0008792877197265625, + -0.0008392333984375, + 0, + -0.014632225036621094, + 0, + -0.0009860992431640625, + -0.0236358642578125, + 0.021772384643554688, + 0, + 0, + 0, + -0.016798019409179688, + 0, + 0, + -0.0022678375244140625, + 0, + -0.0038995742797851562, + 0.006114959716796875, + -0.05572509765625, + -0.008089065551757812, + 0, + 0.21244430541992188, + -0.06043434143066406, + 0, + 0.0001010894775390625, + 0.00023651123046875, + 0, + 0.062018394470214844, + -0.08936023712158203, + 0, + -0.005387306213378906, + -0.001903533935546875, + 0, + 0, + 0, + -0.08661651611328125, + 0.020143508911132812, + 0, + -0.01000213623046875, + 0.008556365966796875, + -0.0023040771484375, + 0.0063114166259765625, + 0, + 0, + 0, + -0.01030731201171875, + 0, + 0, + 0, + -0.037540435791015625, + 0, + 0, + 0, + 0, + -0.018768310546875, + 0, + 0.06715202331542969, + 0, + -0.01861572265625, + 0, + 0.02222919464111328, + -0.0029458999633789062, + -0.0005445480346679688, + -0.001338958740234375, + -0.0246734619140625, + 0, + 0.0014019012451171875, + 0, + 0, + 0, + -0.34259986877441406, + 0, + 0, + 0, + 0, + 0, + -0.002704620361328125, + -0.0001850128173828125, + -0.9704685211181641, + 0, + -0.01996612548828125, + 0, + -0.0199432373046875, + 0, + 0, + 0.025028228759765625, + 0, + 0, + 0, + 0.05844879150390625, + -0.00006961822509765625, + -0.002410888671875 + ], + [ + 0, + 0, + -0.001018524169921875, + 0, + 0, + 0, + 0, + -0.0172882080078125, + 0.05738639831542969, + 0.12810707092285156, + 0, + 0, + 0, + 0, + 0, + -0.0056362152099609375, + 0, + 0, + 0, + 0.009425163269042969, + 0, + 0, + -0.2314128875732422, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.057198524475097656, + 0, + 0, + 0, + 0.13471412658691406, + 0, + 
0.08182525634765625, + 0, + 0, + 0, + 0, + 0, + 0.006465911865234375, + 0, + 0, + 0, + 0, + 0.0039052963256835938, + 0, + -0.0010318756103515625, + 0, + 0, + 0, + 0, + 0.062198638916015625, + 0.0000057220458984375, + -0.001708984375, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.03947257995605469, + 0.1576099395751953, + 0, + 0, + 0.00009822845458984375, + -0.25530242919921875, + 0, + 0.061611175537109375, + 0, + 0, + 0.0061016082763671875, + 0, + 0, + 0, + 0, + 0, + -0.079315185546875, + 0, + 0, + 0.04389762878417969, + 0.06207084655761719, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.0064945220947265625, + -0.009065628051757812, + 0, + 0, + 0.0025882720947265625, + 0, + 0.0033740997314453125, + 0, + 0, + 0, + 0.014276504516601562, + -0.011219978332519531, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.023397445678710938, + 0, + 0, + 0, + 0.0096435546875, + 0, + 0, + 0, + 0, + 0, + 0.007327079772949219, + 0, + 0.00238037109375, + 0, + -0.04556846618652344, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + -0.0007219314575195312, + 0, + 0, + -0.001102447509765625, + 0, + 0, + 0, + -0.00047397613525390625, + 0, + -0.02031421661376953, + -0.18840694427490234, + 0, + 0.0009374618530273438, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0014810562133789062, + 0, + 0, + 0, + -0.01897907257080078, + -0.012393951416015625, + 0, + 0, + 0, + 0, + 0, + -0.007961273193359375, + 0, + 0.006266593933105469, + 0.022070884704589844, + 0, + 0, + -0.00022220611572265625, + 0, + -0.08554744720458984, + 0, + 0, + 0, + 0, + 0, + 0.00211334228515625, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0006618499755859375, + 0, + 0, + 0, + 0.00042629241943359375, + 0, + 0, + -0.0023794174194335938, + 0, + 0, + 0, + -0.08295249938964844, + 0, + 0, + 0.02340221405029297, + 0.05393028259277344, + 0, + 0.0030164718627929688, + 0, + 0, + 0.02137470245361328, + -0.0648040771484375, + 0, + 0, + -0.0007104873657226562, + 0, + 0, + 0, + -0.02891063690185547, + 0, + 0, + 
-0.0024862289428710938, + -0.007077217102050781, + -0.004982948303222656, + 0.004157066345214844, + 0, + 0, + 0, + -0.0009584426879882812, + 0, + 0, + -0.0016260147094726562, + -0.03653144836425781, + 0, + 0, + 0, + 0, + -0.004261970520019531, + 0, + 0.1517467498779297, + 0, + -0.0017957687377929688, + 0, + 0.01949596405029297, + 0, + 0, + 0, + -0.024643898010253906, + 0, + 0, + 0, + 0, + 0, + -0.12193775177001953, + 0, + 0.01824474334716797, + 0.006918907165527344, + 0, + 0, + 0, + 0, + -0.5964584350585938, + 0, + -0.004886627197265625, + -0.0028219223022460938, + -0.013730049133300781, + 0, + 0, + 0, + 0, + 0, + 0.000370025634765625, + 0.11502552032470703, + 0, + 0 + ], + [ + 0, + 0, + 0.0020799636840820312, + 0, + 0, + 0, + 0, + -0.02874469757080078, + 0.0672769546508789, + 0.31006431579589844, + 0, + 0, + 0, + 0, + 0, + -0.014065742492675781, + 0, + 0, + 0, + 0, + 0, + 0, + -0.42875194549560547, + 0, + 0, + 0, + 0, + 0.037166595458984375, + 0, + 0, + 0.00395965576171875, + -0.09044742584228516, + 0, + 0, + 0, + 0.16284751892089844, + 0, + 0.2745513916015625, + 0, + 0, + 0.0013599395751953125, + 0, + 0, + -0.016633033752441406, + 0, + 0, + 0, + 0, + 0.002765655517578125, + 0, + 0, + 0, + 0, + 0, + 0, + 0.06857013702392578, + 0, + 0.0030755996704101562, + 0, + 0.005713462829589844, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.010555267333984375, + 0.35628509521484375, + 0, + 0, + 0, + -0.3705453872680664, + 0, + 0.1321268081665039, + 0, + 0, + 0.01171875, + 0, + 0.006653785705566406, + 0, + 0, + 0, + -0.04768085479736328, + 0, + 0, + 0.05365467071533203, + 0.10848140716552734, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0019636154174804688, + 0, + 0, + 0, + -0.0038604736328125, + -0.00696563720703125, + 0, + 0, + 0, + 0.004207611083984375, + -0.009866714477539062, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.041828155517578125, + 0, + 0, + 0, + 0.03432941436767578, + 0, + 0, + 0, + 0, + 0.02262592315673828, + 0.1012563705444336, + 0, + 
0.0032415390014648438, + 0, + -0.028539657592773438, + 0, + 0, + 0, + 0, + 0, + 0, + 0.019530296325683594 + ], + [ + 0.0072574615478515625, + 0, + 0, + 0.0045604705810546875, + 0, + -0.002410888671875, + 0, + 0.000942230224609375, + 0, + -0.028242111206054688, + -0.06697559356689453, + -0.002197265625, + 0, + 0, + 0, + 0.01448822021484375, + 0.00038909912109375, + 0, + 0, + 0, + 0, + -0.0072345733642578125, + 0.0015048980712890625, + 0, + 0, + 0, + 0, + -0.026609420776367188, + -0.007898330688476562, + 0.006641387939453125, + -0.012470245361328125, + 0, + 0, + 0, + -0.0054531097412109375, + 0.06533622741699219, + 0, + 0.00041484832763671875, + 0, + 0, + 0, + -0.002368927001953125, + 0, + 0.04226112365722656, + 0, + -0.0031299591064453125, + -0.0000457763671875, + 0.000308990478515625, + 0, + 0, + 0, + 0, + -0.0275726318359375, + 0, + 0, + -0.004794120788574219, + 0.01718902587890625, + 0, + -0.001049041748046875, + 0, + -0.007875442504882812, + 0, + -0.00032806396484375, + 0, + 0.002880096435546875, + -0.0073566436767578125, + 0, + -0.012141227722167969, + -0.002796173095703125, + 0, + 0.0904073715209961, + -0.020002365112304688, + 0, + 0.0006046295166015625, + 0.0000095367431640625, + 0, + 0.09020233154296875, + -0.024329185485839844, + 0, + -0.0007257461547851562, + 0.0022792816162109375, + 0, + 0, + 0.0024671554565429688, + -0.031095504760742188, + 0.029073715209960938, + 0, + 0, + 0.017263412475585938, + 0.009774208068847656, + 0.01905059814453125, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.007511138916015625, + 0, + 0, + -0.01740264892578125, + -0.012363433837890625, + -0.007237434387207031, + 0.00046825408935546875, + 0.015039443969726562, + 0, + -0.001247406005859375, + 0, + 0.04442596435546875, + 0, + 0, + 0, + 0.0020885467529296875, + 0.0009975433349609375, + 0.0068645477294921875, + 0, + 0.0009918212890625, + 0.007763862609863281, + -0.10830020904541016, + 0.002170562744140625, + 0.0041522979736328125, + 0.0009832382202148438, + -0.0055789947509765625, + 
-0.0020475387573242188, + 0, + 0, + -0.46219825744628906, + -0.0004138946533203125, + 0.022248268127441406, + 0, + -0.023275375366210938, + 0, + 0, + -0.00007152557373046875, + -0.0017099380493164062, + 0, + 0, + 0.028047561645507812, + 0, + -0.006505012512207031 + ], + [ + 0.0026121139526367188, + 0.0023622512817382812, + 0, + 0, + 0, + 0, + 0, + -0.04861927032470703, + 0.04393959045410156, + 0.24942588806152344, + 0, + 0, + 0, + 0, + 0, + -0.0894918441772461, + 0.011738777160644531, + 0.0023365020751953125, + 0, + 0.03142070770263672, + 0.007035255432128906, + 0.013895988464355469, + -0.38878440856933594, + 0, + 0, + 0, + 0, + 0.3524456024169922, + 0, + 0.04943275451660156, + 0.07975196838378906, + -0.13926124572753906, + 0, + 0.007584571838378906, + 0, + 0.10158729553222656, + 0, + 0.048768043518066406, + 0, + 0, + 0, + -0.010777473449707031, + 0, + -0.02371692657470703, + 0, + 0, + 0, + 0, + -0.0021333694458007812, + 0, + 0, + 0, + 0, + 0, + 0, + 0.14519309997558594, + 0, + -0.023756027221679688, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.038219451904296875, + 0, + 0, + 0, + -0.07305049896240234, + 0.1724720001220703, + 0.035521507263183594, + 0, + 0.026566505432128906, + -0.2165508270263672, + -0.010828971862792969, + 0, + 0, + 0, + 0.06682586669921875, + 0.0020055770874023438, + 0, + 0.05693340301513672, + 0, + 0, + -0.1571969985961914, + 0, + 0, + 0.0275726318359375, + 0.09813213348388672, + 0, + -0.0074253082275390625, + 0, + 0, + -0.00006008148193359375, + 0, + 0, + 0, + 0.007464408874511719, + -0.011278152465820312, + 0, + 0, + 0.008585929870605469, + -0.02161121368408203, + -0.05259227752685547, + 0.15187358856201172, + 0.009034156799316406, + 0, + 0.01724529266357422, + 0.02186107635498047, + 0, + 0, + 0.023595809936523438, + 0.0018739700317382812, + 0.0014142990112304688, + 0.0001888275146484375, + 0, + 0, + 0.14745807647705078, + 0, + 0, + 0, + 0.022150039672851562, + 0.04754352569580078, + 0, + 0, + 0, + 0.12122058868408203, + 0.037743568420410156, 
+ 0, + -0.022559165954589844, + 0, + -0.07815361022949219, + 0, + 0, + 0, + 0, + 0, + 0, + 0.1304798126220703 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + 
"#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + 
}, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + 
], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + 
[ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + 
"zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Change in logit diff when ablating L5 SAE features for all prompts at pos 10" + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Feature Idx" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Prompt Idx" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def ablate_sae_feature(sae_acts, hook, pos, feature_id):\n", + " if pos is None:\n", + " sae_acts[:, :, feature_id] = 0.\n", + " else:\n", + " sae_acts[:, pos, feature_id] = 0.\n", + " return sae_acts\n", + "\n", + "layer = 5\n", + "sae = hook_name_to_sae[utils.get_act_name('z', layer)]\n", + "\n", + "logits_with_saes = model.run_with_saes(tokens, saes=sae)\n", + "clean_sae_baseline_per_prompt = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)\n", + "\n", + "all_live_features = torch.arange(sae.cfg.d_sae)[live_feature_union.cpu()]\n", + "\n", + "causal_effects = torch.zeros((len(prompts), all_live_features.shape[0]))\n", + "fid_to_idx = {fid.item(): idx for idx, fid in enumerate(all_live_features)}\n", + "\n", + "\n", + "abl_layer, abl_pos = 5, 10\n", + "for feature_id in tqdm.tqdm(all_live_features):\n", + " feature_id = feature_id.item()\n", + " abl_feature_logits = model.run_with_hooks_with_saes(\n", + " tokens,\n", + " saes=sae,\n", + " fwd_hooks=[(utils.get_act_name('z', abl_layer) + \".hook_sae_acts_post\", partial(ablate_sae_feature, pos=abl_pos, feature_id=feature_id))]\n", + " ) # [batch, seq, vocab]\n", + " \n", + " abl_feature_logit_diff = logits_to_ave_logit_diff(abl_feature_logits, answer_tokens, per_prompt=True) # [batch]\n", + " causal_effects[:, fid_to_idx[feature_id]] = abl_feature_logit_diff - clean_sae_baseline_per_prompt\n", + "\n", + "\n", + "imshow(causal_effects, title=f\"Change in logit diff when ablating L{abl_layer} SAE features for all prompts at pos {abl_pos}\", xaxis=\"Feature Idx\", yaxis=\"Prompt Idx\", x=list(map(str, all_live_features.tolist())))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Although it's not super clean, we see a few features stand out, where ablating them causes a nontrivial drop in logit diff on multiple prompts: 7515 and 27535 for BABA prompts, with 44256 for ABBA 
prompts." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Add SAEs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "While the `run_with_saes` family of methods are great for evaluating SAEs and exploratory analysis, you may want to permanently attach SAEs to your model. You can attach SAEs to any activation with `model.add_sae(sae)`, where sae is a HookedSAE. \n", + "\n", + "When you add an SAE, it gets stored in `model.acts_to_saes`, a dictionary that maps the activation name to the HookedSAE that is attached. The main benefit of permanently adding SAEs is that we can now just run the model like a normal HookedTransformer (with `forward`, `run_with_cache`, `run_with_hooks`), but some activations will be replaced with the reconstructed activations from the corresponding SAEs.\n", + "\n", + "I expect this to be most useful when you've already identified a good set of SAEs that you want to use for interpretability, and don't feel like passing in a massive list of saes for every forward pass." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Attached SAEs before add_sae {}\n", + "Attached SAEs after add_sae {'blocks.5.attn.hook_z': HookedSAE(\n", + " (hook_sae_input): HookPoint()\n", + " (hook_sae_acts_pre): HookPoint()\n", + " (hook_sae_acts_post): HookPoint()\n", + " (hook_sae_recons): HookPoint()\n", + " (hook_sae_error): HookPoint()\n", + " (hook_sae_output): HookPoint()\n", + ")}\n" + ] + } + ], + "source": [ + "print(\"Attached SAEs before add_sae\", model.acts_to_saes)\n", + "layer = 5\n", + "sae = hook_name_to_sae[utils.get_act_name('z', layer)]\n", + "model.add_sae(sae)\n", + "print(\"Attached SAEs after add_sae\", model.acts_to_saes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can just call the standard HookedTransformer forward, and the sae that we added will automatically be spliced in." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average logit diff with SAEs: 3.6155965328216553\n" + ] + } + ], + "source": [ + "logits_with_saes = model(tokens)\n", + "assert not torch.allclose(original_logits, logits_with_saes, atol=1e-4)\n", + "\n", + "average_logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)\n", + "print(f\"Average logit diff with SAEs: {average_logit_diff_with_saes}\")\n", + "per_prompt_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run with cache" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Similarly, we can also use `logits, cache = model.run_with_cache(tokens)` directly to cache SAE activations:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + 
"data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Feature Id: %{x}
Prompt: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "x": [ + "46", + "345", + "702", + "1372", + "1755", + "1965", + "2457", + "2496", + "2646", + "2999", + "3047", + "4569", + "5132", + "5203", + "5508", + "5940", + "6144", + "6371", + "6515", + "6558", + "6812", + "7092", + "7515", + "7907", + "8063", + "8623", + "8737", + "8768", + "9096", + "9102", + "9186", + "9463", + "9746", + "9913", + "10581", + "10894", + "12109", + "12485", + "12764", + "12866", + "13063", + "13624", + "13707", + "13777", + "14844", + "15050", + "15170", + "15696", + "16178", + "16892", + "17156", + "17259", + "17497", + "17854", + "18043", + "18210", + "18318", + "18385", + "18440", + "18920", + "19183", + "19263", + "19442", + "19524", + "19573", + "20838", + "21151", + "21657", + "22108", + "23578", + "24091", + "24217", + "25792", + "26373", + "26410", + "27535", + "27787", + "27811", + "27960", + "28061", + "28241", + "28242", + "28254", + "28349", + "28977", + "29027", + "29482", + "29603", + "29700", + "29822", + "32177", + "32920", + "33320", + "33730", + "33966", + "34177", + "34334", + "34947", + "35403", + "35425", + "35579", + "35665", + "35815", + "36109", + "36172", + "36451", + "36767", + "36917", + "38570", + "39962", + "40409", + "40418", + "40661", + "41162", + "41185", + "41552", + "42024", + "42161", + "42437", + "42577", + "42882", + "42931", + "43035", + "43414", + "43643", + "43662", + "44203", + "44256", + "44452", + "44652", + "45179", + "45814", + "45984", + "46880", + "47117", + "47170", + "47231", + "47313", + "47680", + "48063", + "48703" + ], + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.23392018675804138, + 0, + 0, + 0.04335343837738037, + 0.44275617599487305, + 0, + 0, + 0.07259953022003174, + 0, + 0.6985604763031006, + 1.262436866760254, + 0, + 0.04656928777694702, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.45666736364364624, + 0.10434150695800781, + 0.30980953574180603, + 0.3319076895713806, + 0, + 0, + 0, + 0, + 
1.7836596965789795, + 0, + 0, + 0.142583429813385, + 0.046830952167510986, + 0.3180348575115204, + 0.2927079200744629, + 0.12267106771469116, + 2.5688514709472656, + 0.2917236089706421, + 0.12333670258522034, + 0, + 0.1778419017791748, + 0, + 0.023626387119293213, + 0.02943563461303711, + 0, + 0.048882365226745605, + 0.13625454902648926, + 0, + 0, + 0.2634885013103485, + 0, + 0, + 0, + 0.21662655472755432, + 0, + 0, + 0, + 0.06997489929199219, + 0.006345987319946289, + 0, + 0.16112494468688965, + 0.4190089702606201, + 0, + 2.3819468021392822, + 1.0431660413742065, + 0, + 0.08364987373352051, + 0, + 0, + 0.3451769948005676, + 0.7391350865364075, + 0.4456520080566406, + 0.0019606351852416992, + 0.39914217591285706, + 0, + 0, + 0, + 0.29958274960517883, + 0.44243645668029785, + 0, + 0.1259920299053192, + 0.8349504470825195, + 0.37993764877319336, + 0.2633737325668335, + 0.08324140310287476, + 0, + 0, + 0.10421907901763916, + 0, + 0, + 0, + 0.36972635984420776, + 0, + 0, + 0, + 0, + 0.5578295588493347, + 0, + 0.9233021140098572, + 0, + 0.10010790824890137, + 0, + 0.45082613825798035, + 0, + 0, + 0, + 0.21043556928634644, + 0.12981292605400085, + 0.11557984352111816, + 0, + 0, + 0.17571094632148743, + 0.2823787331581116, + 0.1122598648071289, + 0, + 0, + 0.012049257755279541, + 0, + 0, + 0, + 2.417463541030884, + 0.0547795295715332, + 0.05216425657272339, + 0, + 0.6592545509338379, + 0.003663182258605957, + 0, + 0, + 0.04937589168548584, + 0.025814831256866455, + 0, + 0.8019273281097412, + 0, + 0.10218703746795654 + ], + [ + 0, + 0, + 0.3230956792831421, + 0, + 0, + 0, + 0.026041746139526367, + 0.31818556785583496, + 0, + 0.4900796413421631, + 0.04911249876022339, + 0, + 0, + 0.07309412956237793, + 0.08089971542358398, + 0.17180073261260986, + 0, + 0, + 0, + 0, + 0, + 0, + 2.3956947326660156, + 0, + 0, + 0.15781426429748535, + 0, + 0.5073252320289612, + 0.21765804290771484, + 0, + 0, + 1.618570327758789, + 0, + 0.22485831379890442, + 0.0830467939376831, + 
0.7055595517158508, + 0, + 0, + 0, + 0, + 0.23371747136116028, + 0, + 0, + 0.6983060240745544, + 0, + 0, + 0, + 0, + 0.30831730365753174, + 0, + 0.417669415473938, + 0.05292201042175293, + 0, + 0, + 0, + 1.3391070365905762, + 0, + 0.41352108120918274, + 0, + 0, + 0, + 0.037178993225097656, + 0, + 0, + 0, + 0, + 0.2702980041503906, + 0, + 0, + 0.18745100498199463, + 1.3330132961273193, + 0.5793700814247131, + 0.33893001079559326, + 0, + 0.11196631193161011, + 1.720167636871338, + 0.17581266164779663, + 0.42567259073257446, + 0, + 0, + 0.23682871460914612, + 0, + 0, + 0, + 0, + 0, + 1.8280882835388184, + 0.1617840826511383, + 0, + 0.13557660579681396, + 0.5832244157791138, + 0, + 0, + 0.03256487846374512, + 0, + 0, + 0.03892314434051514, + 0, + 0, + 0, + 0.30978846549987793, + 0, + 0, + 0.36915141344070435, + 0, + 0.5477294325828552, + 0, + 0, + 0.06339260935783386, + 0.1851767599582672, + 0.5839155912399292, + 0, + 0, + 0, + 0, + 0, + 0.12337607145309448, + 0, + 0, + 1.0378936529159546, + 0, + 0, + 0, + 0.01616498827934265, + 0.20259439945220947, + 0, + 0, + 0.3087460398674011, + 0.618510365486145, + 0.24435847997665405, + 0, + 0.4668591022491455, + 0.1788468360900879, + 0.200361967086792, + 0, + 0, + 0, + 0, + 0, + 0, + 0.7064645290374756 + ], + [ + 0.2921750843524933, + 0, + 0, + 0.2805737257003784, + 0, + 0, + 0, + 0.3694216012954712, + 0, + 1.1156601905822754, + 1.2807728052139282, + 0, + 0.09175515174865723, + 0, + 0, + 0, + 0.10458803176879883, + 0, + 0.021218180656433105, + 0, + 0, + 0.01699376106262207, + 0.09601330757141113, + 0.054788172245025635, + 0, + 0.030488133430480957, + 0.021512210369110107, + 0.2717320919036865, + 0.29357004165649414, + 0.6420693397521973, + 0.05249035358428955, + 0, + 0.06201601028442383, + 0, + 0.4122554659843445, + 1.821354866027832, + 0.01981794834136963, + 0, + 0.14063221216201782, + 0.05093127489089966, + 0, + 0.32148706912994385, + 0.15257668495178223, + 2.418062686920166, + 0.17348229885101318, + 0.08421656489372253, + 0, 
+ 0.4551248550415039, + 0, + 0.015430927276611328, + 0.24434363842010498, + 0, + 0.06232607364654541, + 0, + 0.04422914981842041, + 0.8720088005065918, + 0.3721686899662018, + 0, + 0, + 0, + 0.340120404958725, + 0, + 0, + 0.07813769578933716, + 0, + 0.0882720947265625, + 0.19706517457962036, + 0.4056885242462158, + 0.19529414176940918, + 0, + 2.928431510925293, + 1.1402223110198975, + 0, + 0.026796698570251465, + 0.0033188462257385254, + 0, + 0.3370524048805237, + 0.47657889127731323, + 0, + 0.10358679294586182, + 0.27619925141334534, + 0, + 0, + 0, + 0.40909066796302795, + 0.2599871754646301, + 0, + 0.275011271238327, + 0.5349323749542236, + 0.07697033882141113, + 0.17431437969207764, + 0, + 0, + 0, + 0.09000074863433838, + 0, + 0, + 0, + 0.276567280292511, + 0, + 0, + 0, + 0, + 0.5655339360237122, + 0, + 0.8971189856529236, + 0, + 0.5199201107025146, + 0, + 0.6301102638244629, + 0.013657361268997192, + 0.04469645023345947, + 0.038062095642089844, + 0.4305816888809204, + 0, + 0.04173767566680908, + 0, + 0, + 0, + 0.8985729217529297, + 0, + 0, + 0, + 0, + 0, + 0.08318889141082764, + 0.006362795829772949, + 2.069222927093506, + 0, + 0.7068352103233337, + 0, + 0.8527798652648926, + 0, + 0, + 0.4707651138305664, + 0, + 0, + 0, + 0.7790955305099487, + 0.021227538585662842, + 0.01846003532409668 + ], + [ + 0, + 0, + 0.2200499176979065, + 0, + 0, + 0, + 0, + 0.2433047890663147, + 0.2504638135433197, + 0.712148904800415, + 0, + 0, + 0, + 0, + 0, + 0.1410943865776062, + 0, + 0, + 0, + 0.11292147636413574, + 0, + 0, + 2.360842704772949, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1.2830760478973389, + 0, + 0, + 0, + 0.6308119893074036, + 0, + 0.4040885865688324, + 0, + 0, + 0, + 0, + 0, + 0.5223236680030823, + 0, + 0, + 0, + 0, + 0.23784160614013672, + 0, + 0.04762387275695801, + 0, + 0, + 0, + 0, + 0.5758676528930664, + 0.01025208830833435, + 0.24556085467338562, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1.1104614734649658, + 1.079118251800537, + 0, + 0, 
+ 0.14462929964065552, + 1.9186956882476807, + 0, + 0.30735498666763306, + 0, + 0, + 0.07669633626937866, + 0, + 0, + 0, + 0, + 0, + 1.3975048065185547, + 0, + 0, + 0.3461639881134033, + 0.5062156915664673, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.19610454142093658, + 0.218009352684021, + 0, + 0, + 0.07953745126724243, + 0, + 0.1416093111038208, + 0, + 0, + 0, + 0.18305465579032898, + 0.10310900211334229, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.45315277576446533, + 0, + 0, + 0, + 0.09076884388923645, + 0, + 0, + 0, + 0, + 0, + 0.04246491193771362, + 0, + 0.1807355284690857, + 0, + 0.3002055883407593, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0.02005404233932495, + 0, + 0, + 0.07601284980773926, + 0, + 0, + 0, + 0.012166053056716919, + 0, + 1.0662918090820312, + 1.4810535907745361, + 0, + 0.014786958694458008, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.1491186022758484, + 0, + 0, + 0, + 0.38226866722106934, + 0.43110355734825134, + 0, + 0, + 0, + 0, + 0, + 1.6819074153900146, + 0, + 0.7939910888671875, + 0.28643298149108887, + 0, + 0, + 0.011532962322235107, + 0, + 1.2869157791137695, + 0, + 0, + 0, + 0, + 0, + 0.16446048021316528, + 0, + 0, + 0, + 0, + 0, + 0, + 0.03375712037086487, + 0, + 0, + 0, + 0.1915181577205658, + 0, + 0, + 0.10225892066955566, + 0, + 0, + 0, + 0.7338485717773438, + 0, + 0, + 1.3715617656707764, + 1.6115869283676147, + 0, + 0.7128411531448364, + 0, + 0, + 0.2161598801612854, + 0.5098914504051208, + 0, + 0, + 0.04084053635597229, + 0, + 0, + 0, + 0.17978456616401672, + 0, + 0, + 0.1365671455860138, + 0.27122950553894043, + 0.2945059537887573, + 0.2824629545211792, + 0, + 0, + 0, + 0.0464092493057251, + 0, + 0, + 0.04672741889953613, + 0.6179839968681335, + 0, + 0, + 0, + 0, + 0.045598745346069336, + 0, + 1.0172381401062012, + 0, + 0.07242608070373535, + 0, + 0.5165215730667114, + 0, + 0, + 0, + 0.5004003047943115, + 0, + 0, + 0, + 0, + 0, + 0.3409433960914612, + 0, + 0.1579979658126831, + 0.09901612997055054, + 0, + 
0, + 0, + 0, + 2.413944721221924, + 0, + 0.20971286296844482, + 0.07062971591949463, + 0.26070594787597656, + 0, + 0, + 0, + 0, + 0, + 0.020640969276428223, + 1.0534553527832031, + 0, + 0 + ], + [ + 0, + 0, + 0.046907246112823486, + 0, + 0, + 0, + 0, + 0.20885008573532104, + 0.25957152247428894, + 1.0767037868499756, + 0, + 0, + 0, + 0, + 0, + 0.23976856470108032, + 0, + 0, + 0, + 0, + 0, + 0, + 2.762990951538086, + 0, + 0, + 0, + 0, + 0.29466086626052856, + 0, + 0, + 0.09433537721633911, + 1.2446393966674805, + 0, + 0, + 0, + 0.6668079495429993, + 0, + 0.7482341527938843, + 0, + 0, + 0.005075186491012573, + 0, + 0, + 0.4049275517463684, + 0, + 0, + 0, + 0, + 0.09314888715744019, + 0, + 0, + 0, + 0, + 0, + 0, + 0.4028928279876709, + 0, + 0.3687801659107208, + 0, + 0.10555410385131836, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1.066054105758667, + 1.4596349000930786, + 0, + 0, + 0, + 2.3358588218688965, + 0, + 0.5390753149986267, + 0, + 0, + 0.12931063771247864, + 0, + 0.10619288682937622, + 0, + 0, + 0, + 0.41271400451660156, + 0, + 0, + 0.23865878582000732, + 0.7501264810562134, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.2947666645050049, + 0, + 0, + 0, + 0.05958199501037598, + 0.20450782775878906, + 0, + 0, + 0, + 0.13838836550712585, + 0.13835513591766357, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.45820748805999756, + 0, + 0, + 0, + 0.19962045550346375, + 0, + 0, + 0, + 0, + 0.20416772365570068, + 0.46223968267440796, + 0, + 0.22815394401550293, + 0, + 0.1125795841217041, + 0, + 0, + 0, + 0, + 0, + 0, + 0.3023688793182373 + ], + [ + 0.28365251421928406, + 0, + 0, + 0.41595208644866943, + 0, + 0.15376341342926025, + 0, + 0.22517156600952148, + 0, + 0.7871096134185791, + 1.3084614276885986, + 0.2012956142425537, + 0, + 0, + 0, + 0.2532406449317932, + 0.009012699127197266, + 0, + 0, + 0, + 0, + 0.7235959768295288, + 0.021468758583068848, + 0, + 0, + 0, + 0, + 0.8338297009468079, + 0.3022422790527344, + 0.6702529191970825, + 0.5416026711463928, + 
0, + 0, + 0, + 0.2034381628036499, + 1.9052581787109375, + 0, + 0.23752644658088684, + 0, + 0, + 0, + 0.8470145463943481, + 0, + 2.820002555847168, + 0, + 0.16275432705879211, + 0.06714236736297607, + 0.12017238140106201, + 0, + 0, + 0, + 0, + 0.486280620098114, + 0, + 0, + 0.3096342086791992, + 0.3064201772212982, + 0, + 0.09773910045623779, + 0, + 0.4613642394542694, + 0, + 0.021892428398132324, + 0, + 0.18887782096862793, + 0.18538141250610352, + 0, + 0.42975664138793945, + 0.9873132705688477, + 0, + 2.163774013519287, + 1.2928048372268677, + 0, + 0.2320784330368042, + 0.0062233805656433105, + 0, + 1.2478563785552979, + 0.5479208827018738, + 0, + 0.06501156091690063, + 0.3741762936115265, + 0, + 0, + 0.31712013483047485, + 0.5228050947189331, + 0.3981531858444214, + 0, + 0, + 0.4854400157928467, + 0.3341655731201172, + 0.39207732677459717, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.3316766023635864, + 0, + 0, + 0.33435362577438354, + 0.1380615234375, + 0.7183249592781067, + 0.041296958923339844, + 0.7634149193763733, + 0, + 0.4028007984161377, + 0, + 0.6915435791015625, + 0, + 0, + 0, + 0.3831353187561035, + 0.05798754096031189, + 0.15244710445404053, + 0, + 0.03230410814285278, + 0.2039397656917572, + 0.6142292022705078, + 0.15542924404144287, + 0.07628917694091797, + 0.0812273919582367, + 0.15177401900291443, + 0.10224854946136475, + 0, + 0, + 2.8106069564819336, + 0.3994237184524536, + 0.6397127509117126, + 0, + 0.8949670791625977, + 0, + 0, + 0.18832790851593018, + 0.1450880765914917, + 0, + 0, + 0.6900937557220459, + 0, + 0.14745783805847168 + ], + [ + 0.12055802345275879, + 0.023864269256591797, + 0, + 0, + 0, + 0, + 0, + 0.3327372670173645, + 0.1789897382259369, + 1.1445300579071045, + 0, + 0, + 0, + 0, + 0, + 0.4361664652824402, + 0.09996795654296875, + 0.10051405429840088, + 0, + 0.4030296802520752, + 0.06672021746635437, + 0.6339577436447144, + 3.3947582244873047, + 0, + 0, + 0, + 0, + 0.9711236357688904, + 0, + 0.38066884875297546, + 0.4158353805541992, + 
1.5344438552856445, + 0, + 0.19816407561302185, + 0, + 0.6646860241889954, + 0, + 0.16733816266059875, + 0, + 0, + 0, + 0.322623074054718, + 0, + 0.7314171195030212, + 0, + 0, + 0, + 0, + 0.043955981731414795, + 0, + 0, + 0, + 0, + 0, + 0, + 0.9436180591583252, + 0, + 0.29259607195854187, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.1570979356765747, + 0, + 0, + 0, + 1.1782727241516113, + 1.2431498765945435, + 0.32878363132476807, + 0, + 0.419150173664093, + 2.3304405212402344, + 0.8566346764564514, + 0, + 0, + 0, + 0.3841046392917633, + 0.10476112365722656, + 0, + 0.18140661716461182, + 0, + 0, + 0.6665420532226562, + 0, + 0, + 0.22877633571624756, + 0.9225524663925171, + 0, + 0.15886402130126953, + 0, + 0, + 0.02094721794128418, + 0, + 0, + 0, + 0.3046541213989258, + 0.2845715284347534, + 0, + 0, + 0.4244043231010437, + 0.164473295211792, + 0.30073386430740356, + 0.7123112678527832, + 0.1730642318725586, + 0, + 0.4041661322116852, + 0.39166414737701416, + 0, + 0, + 0.2103893756866455, + 0.007811635732650757, + 0.010994672775268555, + 0.03914850950241089, + 0, + 0, + 0.8430832624435425, + 0, + 0, + 0, + 0.15830591320991516, + 0.29398930072784424, + 0, + 0, + 0, + 0.5994948148727417, + 0.1704254150390625, + 0, + 0.4673898220062256, + 0, + 0.3204514980316162, + 0, + 0, + 0, + 0, + 0, + 0, + 0.8447363376617432 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, 
+ "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": 
"" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + 
"fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { 
+ "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + 
"bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Activations of Live SAE features at L5 S2 position per prompt" + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Feature Id" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Prompt" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "layer = 5\n", + "_, cache = model.run_with_cache(tokens)\n", + "s2_pos = 10\n", + "sae_acts = cache[utils.get_act_name('z', layer) + \".hook_sae_acts_post\"][:, s2_pos, :]\n", + "\n", + "live_feature_mask = sae_acts > 0\n", + "live_feature_union = live_feature_mask.any(dim=0)\n", + "\n", + "imshow(\n", + " sae_acts[:, live_feature_union],\n", + " title = \"Activations of Live SAE features at L5 S2 position per prompt\",\n", + " xaxis=\"Feature Id\", yaxis=\"Prompt\",\n", + " x=list(map(str, live_feature_union.nonzero().flatten().tolist())),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run with hooks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally we can also use `run_with_hooks` and intervene on the added SAE's activations. To show a more complicated intervention, we'll try path patching the feature from the S-inhibition head's value vectors." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "model.set_use_split_qkv_input(True)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 141/141 [00:05<00:00, 26.94it/s]\n" + ] + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Feature Id: %{x}
Prompt Idx: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "x": [ + "46", + "345", + "702", + "1372", + "1755", + "1965", + "2457", + "2496", + "2646", + "2999", + "3047", + "4569", + "5132", + "5203", + "5508", + "5940", + "6144", + "6371", + "6515", + "6558", + "6812", + "7092", + "7515", + "7907", + "8063", + "8623", + "8737", + "8768", + "9096", + "9102", + "9186", + "9463", + "9746", + "9913", + "10581", + "10894", + "12109", + "12485", + "12764", + "12866", + "13063", + "13624", + "13707", + "13777", + "14844", + "15050", + "15170", + "15696", + "16178", + "16892", + "17156", + "17259", + "17497", + "17854", + "18043", + "18210", + "18318", + "18385", + "18440", + "18920", + "19183", + "19263", + "19442", + "19524", + "19573", + "20838", + "21151", + "21657", + "22108", + "23578", + "24091", + "24217", + "25792", + "26373", + "26410", + "27535", + "27787", + "27811", + "27960", + "28061", + "28241", + "28242", + "28254", + "28349", + "28977", + "29027", + "29482", + "29603", + "29700", + "29822", + "32177", + "32920", + "33320", + "33730", + "33966", + "34177", + "34334", + "34947", + "35403", + "35425", + "35579", + "35665", + "35815", + "36109", + "36172", + "36451", + "36767", + "36917", + "38570", + "39962", + "40409", + "40418", + "40661", + "41162", + "41185", + "41552", + "42024", + "42161", + "42437", + "42577", + "42882", + "42931", + "43035", + "43414", + "43643", + "43662", + "44203", + "44256", + "44452", + "44652", + "45179", + "45814", + "45984", + "46880", + "47117", + "47170", + "47231", + "47313", + "47680", + "48063", + "48703" + ], + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.0005645751953125, + 0.0000057220458984375, + 0.0000057220458984375, + 0.000339508056640625, + -0.003261566162109375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.00069427490234375, + 0.0000057220458984375, + 0.0016155242919921875, + -0.09088897705078125, + 0.0000057220458984375, + 0.00011444091796875, + 0.0000057220458984375, + 0.0000057220458984375, + 
0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + -0.009515762329101562, + -0.0022525787353515625, + 0.0031604766845703125, + -0.0020704269409179688, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + -0.013577461242675781, + 0.0000057220458984375, + 0.0000057220458984375, + -0.0017032623291015625, + 0.0002880096435546875, + -0.00020503997802734375, + -0.0016231536865234375, + 0.00037860870361328125, + -0.0098114013671875, + -0.002185821533203125, + -0.0008878707885742188, + 0.0000057220458984375, + 0.0002346038818359375, + 0.0000057220458984375, + -0.000354766845703125, + 0.00036334991455078125, + 0.0000057220458984375, + -0.000988006591796875, + -0.00044918060302734375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.005593299865722656, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + -0.005214691162109375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + -0.000789642333984375, + 0.00010585784912109375, + 0.0000057220458984375, + -0.0059051513671875, + 0.0011091232299804688, + 0.0000057220458984375, + 0.026823997497558594, + 0.019052505493164062, + 0.0000057220458984375, + 0.0000152587890625, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0033597946166992188, + -0.020666122436523438, + -0.0041141510009765625, + -0.000011444091796875, + 0.00130462646484375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + -0.01567840576171875, + 0.006500244140625, + 0.0000057220458984375, + 0.002086639404296875, + 0.00576019287109375, + 0.004245758056640625, + 0.006832122802734375, + 0.0006284713745117188, + 0.0000057220458984375, + 0.0000057220458984375, + -0.0009737014770507812, + 
0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + -0.0040988922119140625, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + -0.003326416015625, + 0.0000057220458984375, + 0.020755767822265625, + 0.0000057220458984375, + -0.0008373260498046875, + 0.0000057220458984375, + 0.007825851440429688, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + -0.002574920654296875, + 0.00151824951171875, + -0.00008678436279296875, + 0.0000057220458984375, + 0.0000057220458984375, + 0.001171112060546875, + -0.02040863037109375, + -0.0014247894287109375, + 0.0000057220458984375, + 0.0000057220458984375, + 0.00003814697265625, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0000057220458984375, + -0.3322334289550781, + 0.000579833984375, + 0.001293182373046875, + 0.0000057220458984375, + 0.0066661834716796875, + 0.0000171661376953125, + 0.0000057220458984375, + 0.0000057220458984375, + 0.0005435943603515625, + 0.00032806396484375, + 0.0000057220458984375, + 0.023120880126953125, + 0.0000057220458984375, + -0.0017566680908203125 + ], + [ + 0.00000762939453125, + 0.00000762939453125, + 0.0040073394775390625, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.0022735595703125, + -0.0012683868408203125, + 0.00000762939453125, + 0.017993927001953125, + 0.011075973510742188, + 0.00000762939453125, + 0.00000762939453125, + -0.001407623291015625, + -0.000270843505859375, + -0.010431289672851562, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + -0.6347770690917969, + 0.00000762939453125, + 0.00000762939453125, + 0.0005435943603515625, + 0.00000762939453125, + 0.09274864196777344, + 0.008495330810546875, + 0.00000762939453125, + 0.00000762939453125, + -0.08464431762695312, + 0.00000762939453125, + 0.028835296630859375, + 0.01250457763671875, + 0.029806137084960938, + 
0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.012714385986328125, + 0.00000762939453125, + 0.00000762939453125, + -0.0004444122314453125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.003757476806640625, + 0.00000762939453125, + 0.0025272369384765625, + 0.0013427734375, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.07260704040527344, + 0.00000762939453125, + 0.01149749755859375, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + -0.000213623046875, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + -0.016370773315429688, + 0.00000762939453125, + 0.00000762939453125, + -0.00792694091796875, + 0.03365135192871094, + -0.004932403564453125, + 0.005069732666015625, + 0.00000762939453125, + 0.0031223297119140625, + -0.5932121276855469, + -0.0007534027099609375, + 0.05148506164550781, + 0.00000762939453125, + 0.00000762939453125, + 0.014024734497070312, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + -0.11317634582519531, + -0.0026416778564453125, + 0.00000762939453125, + -0.006038665771484375, + 0.00672149658203125, + 0.00000762939453125, + 0.00000762939453125, + 0.000064849853515625, + 0.00000762939453125, + 0.00000762939453125, + 0.0005397796630859375, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + -0.0024967193603515625, + 0.00000762939453125, + 0.00000762939453125, + 0.016933441162109375, + 0.00000762939453125, + -0.0049343109130859375, + 0.00000762939453125, + 0.00000762939453125, + -0.00244140625, + -0.00624847412109375, + 0.018770217895507812, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + -0.001132965087890625, + 0.00000762939453125, + 0.00000762939453125, + 0.1962738037109375, + 0.00000762939453125, + 
0.00000762939453125, + 0.00000762939453125, + -0.0005283355712890625, + 0.0070934295654296875, + 0.00000762939453125, + 0.00000762939453125, + 0.10946464538574219, + 0.05410957336425781, + -0.0026397705078125, + 0.00000762939453125, + 0.005107879638671875, + 0.006359100341796875, + -0.04090118408203125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.00000762939453125, + 0.06792449951171875 + ], + [ + 0.0032672882080078125, + 0.00000667572021484375, + 0.00000667572021484375, + 0.0026044845581054688, + 0.00000667572021484375, + 0.00000667572021484375, + 0.00000667572021484375, + -0.0013751983642578125, + 0.00000667572021484375, + 0.018096923828125, + -0.29747962951660156, + 0.00000667572021484375, + 0.00159454345703125, + 0.00000667572021484375, + 0.00000667572021484375, + 0.00000667572021484375, + -0.00185394287109375, + 0.00000667572021484375, + 0.000064849853515625, + 0.00000667572021484375, + 0.00000667572021484375, + 0.0004253387451171875, + 0.02138805389404297, + 0.000370025634765625, + 0.00000667572021484375, + -0.0002880096435546875, + 0.000560760498046875, + -0.03230476379394531, + -0.02060699462890625, + 0.020964622497558594, + -0.0022487640380859375, + 0.00000667572021484375, + 0.001964569091796875, + 0.00000667572021484375, + -0.07773113250732422, + -0.042862892150878906, + 0.00027751922607421875, + 0.00000667572021484375, + -0.0020580291748046875, + 0.001407623291015625, + 0.00000667572021484375, + -0.0008306503295898438, + 0.00371551513671875, + -0.08299636840820312, + -0.0030012130737304688, + -0.0021905899047851562, + 0.00000667572021484375, + 0.011617660522460938, + 0.00000667572021484375, + -0.0000152587890625, + 0.005359649658203125, + 0.00000667572021484375, + -0.0042018890380859375, + 0.00000667572021484375, + 0.0008802413940429688, + -0.049579620361328125, + 0.010822296142578125, + 0.00000667572021484375, + 0.00000667572021484375, + 0.00000667572021484375, + 
-0.014369964599609375, + 0.00000667572021484375, + 0.00000667572021484375, + -0.0016632080078125, + 0.00000667572021484375, + 0.0035800933837890625, + 0.024021148681640625, + -0.04512596130371094, + -0.0006885528564453125, + 0.00000667572021484375, + 0.013338088989257812, + 0.06371307373046875, + 0.00000667572021484375, + 0.000629425048828125, + 0.00002002716064453125, + 0.00000667572021484375, + 0.015112876892089844, + -0.05301094055175781, + 0.00000667572021484375, + -0.0011320114135742188, + 0.0012521743774414062, + 0.00000667572021484375, + 0.00000667572021484375, + 0.00000667572021484375, + -0.038700103759765625, + -0.0035238265991210938, + 0.00000667572021484375, + 0.00608062744140625, + -0.011157035827636719, + 0.004566192626953125, + 0.0062274932861328125, + 0.00000667572021484375, + 0.00000667572021484375, + 0.00000667572021484375, + -0.0015010833740234375, + 0.00000667572021484375, + 0.00000667572021484375, + 0.00000667572021484375, + -0.010572433471679688, + 0.00000667572021484375, + 0.00000667572021484375, + 0.00000667572021484375, + 0.00000667572021484375, + -0.016614913940429688, + 0.00000667572021484375, + 0.030905723571777344, + 0.00000667572021484375, + -0.015107154846191406, + 0.00000667572021484375, + 0.012714385986328125, + -0.0009021759033203125, + -0.00067138671875, + 0.0006847381591796875, + -0.005970954895019531, + 0.00000667572021484375, + 0.000392913818359375, + 0.00000667572021484375, + 0.00000667572021484375, + 0.00000667572021484375, + -0.20943737030029297, + 0.00000667572021484375, + 0.00000667572021484375, + 0.00000667572021484375, + 0.00000667572021484375, + 0.00000667572021484375, + 0.0024538040161132812, + -0.00016117095947265625, + -0.6926145553588867, + 0.00000667572021484375, + -0.006705284118652344, + 0.00000667572021484375, + 0.013433456420898438, + 0.00000667572021484375, + 0.00000667572021484375, + 0.0039653778076171875, + 0.00000667572021484375, + 0.00000667572021484375, + 0.00000667572021484375, + 0.05192756652832031, + 
-0.00046539306640625, + -0.0010156631469726562 + ], + [ + -0.00000667572021484375, + -0.00000667572021484375, + 0.0073337554931640625, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + 0.0017852783203125, + 0.021762847900390625, + 0.023838043212890625, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.0093231201171875, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + 0.00185394287109375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.7318296432495117, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.06693649291992188, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + 0.04135417938232422, + -0.00000667572021484375, + 0.0012073516845703125, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.0023708343505859375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + 0.00597381591796875, + -0.00000667572021484375, + 0.0001049041748046875, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + 0.04203224182128906, + -0.000133514404296875, + 0.0032367706298828125, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + 0.053966522216796875, + -0.017469406127929688, + -0.00000667572021484375, + -0.00000667572021484375, 
+ 0.0032787322998046875, + -0.8294486999511719, + -0.00000667572021484375, + 0.042545318603515625, + -0.00000667572021484375, + -0.00000667572021484375, + 0.006573677062988281, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.1314229965209961, + -0.00000667572021484375, + -0.00000667572021484375, + -0.022655487060546875, + 0.0008211135864257812, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.019756317138671875, + -0.0028676986694335938, + -0.00000667572021484375, + -0.00000667572021484375, + 0.0034084320068359375, + -0.00000667572021484375, + 0.0000171661376953125, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.0022497177124023438, + 0.00191497802734375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + 0.09851455688476562, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.003956794738769531, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + 0.0011348724365234375, + -0.00000667572021484375, + 0.0007839202880859375, + -0.00000667572021484375, + -0.0783843994140625, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375, + -0.00000667572021484375 + ], + [ + -0.00021839141845703125, + 0.00000286102294921875, + 0.00000286102294921875, + -0.00017833709716796875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + -0.00004863739013671875, + 
0.00000286102294921875, + 0.0024118423461914062, + -0.1688375473022461, + 0.00000286102294921875, + 0.0005617141723632812, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + -0.0027265548706054688, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + -0.009179115295410156, + 0.011872291564941406, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.006281852722167969, + 0.00000286102294921875, + 0.011416435241699219, + 0.014454841613769531, + 0.00000286102294921875, + 0.00000286102294921875, + -0.00018596649169921875, + 0.00000286102294921875, + 0.012002944946289062, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + -0.0023813247680664062, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.000225067138671875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + -0.0033779144287109375, + 0.00000286102294921875, + 0.00000286102294921875, + -0.0017099380493164062, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + -0.05732154846191406, + 0.00000286102294921875, + 0.00000286102294921875, + 0.016089439392089844, + 0.07070255279541016, + 0.00000286102294921875, + 0.014483451843261719, + 0.00000286102294921875, + 0.00000286102294921875, + 0.0017747879028320312, + -0.024786949157714844, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00012302398681640625, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + -0.0092620849609375, + 0.00000286102294921875, + 
0.00000286102294921875, + 0.00185394287109375, + -0.00025177001953125, + 0.008860588073730469, + 0.006030082702636719, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00017833709716796875, + 0.00000286102294921875, + 0.00000286102294921875, + -0.001644134521484375, + 0.0026140213012695312, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + -0.0013418197631835938, + 0.00000286102294921875, + 0.037514686584472656, + 0.00000286102294921875, + -0.00038433074951171875, + 0.00000286102294921875, + 0.01964282989501953, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.005845069885253906, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + -0.04890918731689453, + 0.00000286102294921875, + 0.008494377136230469, + -0.00026988983154296875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + -0.37475109100341797, + 0.00000286102294921875, + 0.004479408264160156, + -0.0015649795532226562, + 0.00385284423828125, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00000286102294921875, + 0.00030803680419921875, + 0.06992149353027344, + 0.00000286102294921875, + 0.00000286102294921875 + ], + [ + 0.0000019073486328125, + 0.0000019073486328125, + 0.0018415451049804688, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0016222000122070312, + 0.023705482482910156, + 0.07090950012207031, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + -0.021169662475585938, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + -1.3031587600708008, + 
0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.08781909942626953, + 0.0000019073486328125, + 0.0000019073486328125, + 0.016541481018066406, + -0.10686969757080078, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.04713726043701172, + 0.0000019073486328125, + 0.002704620361328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.00046062469482421875, + 0.0000019073486328125, + 0.0000019073486328125, + -0.01665210723876953, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0031337738037109375, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0426177978515625, + 0.0000019073486328125, + 0.018036842346191406, + 0.0000019073486328125, + -0.001964569091796875, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0788869857788086, + -0.03188610076904297, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + -1.4886322021484375, + 0.0000019073486328125, + 0.0885171890258789, + 0.0000019073486328125, + 0.0000019073486328125, + 0.01448822021484375, + 0.0000019073486328125, + -0.0066547393798828125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + -0.045001983642578125, + 0.0000019073486328125, + 0.0000019073486328125, + -0.017113685607910156, + 0.010157585144042969, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + -0.0030698776245117188, + 0.0000019073486328125, + 0.0000019073486328125, + 
0.0000019073486328125, + -0.0001583099365234375, + -0.004227638244628906, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + -0.008556365966796875, + 0.007357597351074219, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.13220977783203125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + -0.013454437255859375, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.033707618713378906, + 0.006083488464355469, + 0.0000019073486328125, + 0.0014142990112304688, + 0.0000019073486328125, + -0.04172039031982422, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.0000019073486328125, + 0.03891944885253906 + ], + [ + 0.0013017654418945312, + 0.00000858306884765625, + 0.00000858306884765625, + 0.0019664764404296875, + 0.00000858306884765625, + 0.0035953521728515625, + 0.00000858306884765625, + 0.0006504058837890625, + 0.00000858306884765625, + 0.0031061172485351562, + -0.07722282409667969, + -0.0011444091796875, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + 0.0056209564208984375, + 0.00003147125244140625, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + 0.015537261962890625, + 0.001983642578125, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + -0.025580406188964844, + -0.005356788635253906, + 0.016262054443359375, + -0.005573272705078125, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + -0.0113525390625, + 0.013624191284179688, + 0.00000858306884765625, + 0.000110626220703125, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + 
0.000293731689453125, + 0.00000858306884765625, + 0.026404380798339844, + 0.00000858306884765625, + 0.0005817413330078125, + 0.00007343292236328125, + 0.0010223388671875, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + -0.009862899780273438, + 0.00000858306884765625, + 0.00000858306884765625, + -0.006870269775390625, + 0.00435638427734375, + 0.00000858306884765625, + 0.000232696533203125, + 0.00000858306884765625, + -0.00616455078125, + 0.00000858306884765625, + 0.00033283233642578125, + 0.00000858306884765625, + -0.0016880035400390625, + 0.00286102294921875, + 0.00000858306884765625, + -0.01665496826171875, + 0.008039474487304688, + 0.00000858306884765625, + 0.03484916687011719, + 0.018899917602539062, + 0.00000858306884765625, + 0.00034809112548828125, + -0.0000095367431640625, + 0.00000858306884765625, + 0.022369384765625, + -0.00615692138671875, + 0.00000858306884765625, + -0.00008392333984375, + -0.0018634796142578125, + 0.00000858306884765625, + 0.00000858306884765625, + -0.0065765380859375, + -0.00798797607421875, + 0.007740974426269531, + 0.00000858306884765625, + 0.00000858306884765625, + 0.0047855377197265625, + 0.00484466552734375, + 0.006256103515625, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + -0.005949974060058594, + 0.00000858306884765625, + 0.00000858306884765625, + -0.0056934356689453125, + -0.0057353973388671875, + -0.005535125732421875, + 0.00028228759765625, + 0.0137786865234375, + 0.00000858306884765625, + 0.0026874542236328125, + 0.00000858306884765625, + 0.01714324951171875, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00000858306884765625, + 0.00165557861328125, + 0.0006313323974609375, + -0.00090789794921875, + 0.00000858306884765625, + 0.00016021728515625, + 0.00311279296875, + -0.04284191131591797, + -0.00058746337890625, + 
0.0028972625732421875, + -0.001148223876953125, + 0.0013751983642578125, + -0.0005426406860351562, + 0.00000858306884765625, + 0.00000858306884765625, + -0.29439735412597656, + 0.0019617080688476562, + 0.018915176391601562, + 0.00000858306884765625, + 0.009466171264648438, + 0.00000858306884765625, + 0.00000858306884765625, + 0.0011997222900390625, + 0.001117706298828125, + 0.00000858306884765625, + 0.00000858306884765625, + 0.02146148681640625, + 0.00000858306884765625, + -0.0012531280517578125 + ], + [ + 0.005249977111816406, + 0.0015926361083984375, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + 0.0021581649780273438, + 0.01883697509765625, + 0.0733022689819336, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.04300212860107422, + 0.0030794143676757812, + -0.0017910003662109375, + -0.00000286102294921875, + 0.016645431518554688, + -0.021103858947753906, + 0.013091087341308594, + -1.6041021347045898, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + 0.3691072463989258, + -0.00000286102294921875, + -0.01113128662109375, + 0.09581279754638672, + -0.11300373077392578, + -0.00000286102294921875, + 0.047149658203125, + -0.00000286102294921875, + 0.053336143493652344, + -0.00000286102294921875, + 0.004380226135253906, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.008252143859863281, + -0.00000286102294921875, + -0.018776893615722656, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + 0.0016984939575195312, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + 0.10451030731201172, + -0.00000286102294921875, + 
0.010519981384277344, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.014172554016113281, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + 0.07860374450683594, + -0.047211647033691406, + 0.010329246520996094, + -0.00000286102294921875, + 0.02579212188720703, + -1.5303049087524414, + -0.020979881286621094, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + 0.05430316925048828, + 0.006442070007324219, + -0.00000286102294921875, + 0.035637855529785156, + -0.00000286102294921875, + -0.00000286102294921875, + -0.0784912109375, + -0.00000286102294921875, + -0.00000286102294921875, + -0.020351409912109375, + 0.02591228485107422, + -0.00000286102294921875, + 0.0030069351196289062, + -0.00000286102294921875, + -0.00000286102294921875, + 0.0012063980102539062, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.04748249053955078, + -0.00510406494140625, + -0.00000286102294921875, + -0.00000286102294921875, + 0.03345203399658203, + -0.0017213821411132812, + -0.008072853088378906, + 0.014155387878417969, + -0.003909111022949219, + -0.00000286102294921875, + -0.02114105224609375, + 0.021615028381347656, + -0.00000286102294921875, + -0.00000286102294921875, + 0.011925697326660156, + 0.0005092620849609375, + 0.000263214111328125, + -0.00007343292236328125, + -0.00000286102294921875, + -0.00000286102294921875, + 0.2987813949584961, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.011395454406738281, + 0.01917552947998047, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + 0.12213993072509766, + 0.0026998519897460938, + -0.00000286102294921875, + 0.009751319885253906, + -0.00000286102294921875, + -0.12412357330322266, + -0.00000286102294921875, + 
-0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + -0.00000286102294921875, + 0.15506553649902344 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + 
"#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + 
"colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + 
"ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + 
"#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + 
"standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Change in logit diff when path patching features from S_inhibition heads values per prompts" + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Feature Id" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Prompt Idx" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def path_patch_v_input(v_input, hook, feature_dirs, pos, head_index):\n", + " v_input[:, pos, head_index, :] = v_input[:, pos, head_index, :] - feature_dirs\n", + " return v_input\n", + "\n", + "\n", + "s_inhib_heads = [(7, 3), (7, 9), (8,6), (8,10)]\n", + "\n", + "results = torch.zeros(tokens.shape[0], all_live_features.shape[0])\n", + "\n", + "W_O_cat = einops.rearrange(\n", + " model.W_O,\n", + " \"n_layers n_heads d_head d_model -> n_layers (n_heads d_head) d_model\"\n", + ")\n", + "\n", + "for feature_id in tqdm.tqdm(all_live_features):\n", + " feature_id = feature_id.item()\n", + " feature_acts = cache[utils.get_act_name('z', abl_layer) + \".hook_sae_acts_post\"][:, abl_pos, feature_id] # [batch]\n", + " feature_dirs = (feature_acts.unsqueeze(-1) * sae.W_dec[feature_id]) @ W_O_cat[abl_layer]\n", + " hook_fns = [\n", + " (utils.get_act_name('v_input', layer), partial(path_patch_v_input, feature_dirs=feature_dirs, pos=abl_pos, head_index=head)) for (layer, head) in s_inhib_heads\n", + " ]\n", + " path_patched_logits = model.run_with_hooks(\n", + " tokens,\n", + " return_type=\"logits\",\n", + " fwd_hooks=hook_fns\n", + " )\n", + "\n", + " path_patched_logit_diff = logits_to_ave_logit_diff(path_patched_logits, answer_tokens, per_prompt=True)\n", + " results[:, fid_to_idx[feature_id]] = path_patched_logit_diff - clean_sae_baseline_per_prompt\n", + "\n", + "imshow(\n", + " results, \n", + " title=f\"Change in logit diff when path patching features from S_inhibition heads values per prompts\",\n", + " xaxis=\"Feature Id\", yaxis=\"Prompt Idx\", x=list(map(str, all_live_features.tolist()))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Reset SAEs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One major footgun is forgetting about an SAE that you previously attached with `add_sae`. 
Similar to TransformerLens `reset_hooks`, you can always reset SAEs you've added with `model.reset_saes()`. You can also pass in a list of activation names to only reset a subset of attached SAEs." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Attached SAEs before reset_saes: {'blocks.5.attn.hook_z': HookedSAE(\n", + " (hook_sae_input): HookPoint()\n", + " (hook_sae_acts_pre): HookPoint()\n", + " (hook_sae_acts_post): HookPoint()\n", + " (hook_sae_recons): HookPoint()\n", + " (hook_sae_error): HookPoint()\n", + " (hook_sae_output): HookPoint()\n", + ")}\n", + "Attached SAEs after reset_saes: {}\n" + ] + } + ], + "source": [ + "print(\"Attached SAEs before reset_saes:\", model.acts_to_saes)\n", + "model.reset_saes()\n", + "print(\"Attached SAEs after reset_saes:\", model.acts_to_saes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that the HookedSAETransformer API is generally designed to closely match TransformerLens hooks API." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Error Nodes" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Recent exciting work from [Marks et al.](https://arxiv.org/abs/2403.19647v2) demonstrated the use of \"error nodes\" in SAE circuit analysis. The idea is that for some input activation x, SAE(x) = x_reconstruct is an approximation of x, but we can define an error_term such that x = x_reconstruct + error_term.\n", + "\n", + "This seems useful: instead of replacing x with x_reconstruct, which might break everything and make our circuit analysis janky, we can just re-write x as a function of the SAE features, bias, and error term, which gives us access to all of the SAE features but without breaking performance. 
\n", + "\n", + "Additionally, we can compare interventions on SAE features to the same intervention on the error term to get a better sense of how much the SAE features have actually captured.\n", + "\n", + "To use error terms with HookedSAEs, you can set `hooked_sae.cfg.use_error_term = True`, or initialize it to True in the config. Note HookedSAEConfig sets this to False by default." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Attached SAEs after adding l5_sae_with_error: {'blocks.5.attn.hook_z': HookedSAE(\n", + " (hook_sae_input): HookPoint()\n", + " (hook_sae_acts_pre): HookPoint()\n", + " (hook_sae_acts_post): HookPoint()\n", + " (hook_sae_recons): HookPoint()\n", + " (hook_sae_error): HookPoint()\n", + " (hook_sae_output): HookPoint()\n", + ")}\n" + ] + } + ], + "source": [ + "import copy\n", + "l5_sae = hook_name_to_sae[utils.get_act_name('z', 5)]\n", + "l5_sae_with_error = copy.deepcopy(l5_sae)\n", + "l5_sae_with_error.cfg.use_error_term=True\n", + "model.add_sae(l5_sae_with_error)\n", + "print(\"Attached SAEs after adding l5_sae_with_error:\", model.acts_to_saes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now the output of each attached SAE will be SAE(x) + error_term = x. We can sanity check this by confirming that running with SAEs produces the same logits without SAEs." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "logits_with_saes = model(tokens)\n", + "logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)\n", + "\n", + "assert torch.allclose(logits_with_saes, original_logits, atol=1e-4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can compare ablations of each feature to ablating the error node. We'll start by ablating each feature on each prompt, and then the error nodes. 
We'll append the effects from ablating error nodes to the rightmost column on the heatmap:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 141/141 [00:04<00:00, 32.33it/s]\n" + ] + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Feature Idx: %{x}
Prompt Idx: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "x": [ + "46", + "345", + "702", + "1372", + "1755", + "1965", + "2457", + "2496", + "2646", + "2999", + "3047", + "4569", + "5132", + "5203", + "5508", + "5940", + "6144", + "6371", + "6515", + "6558", + "6812", + "7092", + "7515", + "7907", + "8063", + "8623", + "8737", + "8768", + "9096", + "9102", + "9186", + "9463", + "9746", + "9913", + "10581", + "10894", + "12109", + "12485", + "12764", + "12866", + "13063", + "13624", + "13707", + "13777", + "14844", + "15050", + "15170", + "15696", + "16178", + "16892", + "17156", + "17259", + "17497", + "17854", + "18043", + "18210", + "18318", + "18385", + "18440", + "18920", + "19183", + "19263", + "19442", + "19524", + "19573", + "20838", + "21151", + "21657", + "22108", + "23578", + "24091", + "24217", + "25792", + "26373", + "26410", + "27535", + "27787", + "27811", + "27960", + "28061", + "28241", + "28242", + "28254", + "28349", + "28977", + "29027", + "29482", + "29603", + "29700", + "29822", + "32177", + "32920", + "33320", + "33730", + "33966", + "34177", + "34334", + "34947", + "35403", + "35425", + "35579", + "35665", + "35815", + "36109", + "36172", + "36451", + "36767", + "36917", + "38570", + "39962", + "40409", + "40418", + "40661", + "41162", + "41185", + "41552", + "42024", + "42161", + "42437", + "42577", + "42882", + "42931", + "43035", + "43414", + "43643", + "43662", + "44203", + "44256", + "44452", + "44652", + "45179", + "45814", + "45984", + "46880", + "47117", + "47170", + "47231", + "47313", + "47680", + "48063", + "48703", + "error" + ], + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.0012617111206054688, + -9.5367431640625e-7, + -9.5367431640625e-7, + 0.0016908645629882812, + -0.0002231597900390625, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.00029659271240234375, + -9.5367431640625e-7, + -0.03279590606689453, + -0.07254886627197266, + -9.5367431640625e-7, + 0.00013065338134765625, + -9.5367431640625e-7, + -9.5367431640625e-7, + 
-9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.014922142028808594, + -0.0044403076171875, + 0.0007047653198242188, + -0.00428009033203125, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.039069175720214844, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.007334709167480469, + 0.00033092498779296875, + -0.0017004013061523438, + 0.0026845932006835938, + 0.00043010711669921875, + -0.11128997802734375, + -0.0038976669311523438, + -0.006033897399902344, + -9.5367431640625e-7, + -0.00027751922607421875, + -9.5367431640625e-7, + 0.0006570816040039062, + -0.0004291534423828125, + -9.5367431640625e-7, + -0.0035734176635742188, + -0.0033063888549804688, + -9.5367431640625e-7, + -9.5367431640625e-7, + 0.0033960342407226562, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.0030546188354492188, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.0000972747802734375, + -0.0001811981201171875, + -9.5367431640625e-7, + -0.004569053649902344, + -0.013583183288574219, + -9.5367431640625e-7, + 0.02047252655029297, + -0.02572154998779297, + -9.5367431640625e-7, + -0.0006608963012695312, + -9.5367431640625e-7, + -9.5367431640625e-7, + 0.02255725860595703, + -0.05519580841064453, + -0.0033473968505859375, + -0.0000057220458984375, + -0.0026073455810546875, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.02097320556640625, + 0.008440971374511719, + -9.5367431640625e-7, + -0.004597663879394531, + 0.00159454345703125, + 0.0001544952392578125, + 0.005199432373046875, + 0.0007762908935546875, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.0032625198364257812, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + 
-0.015192985534667969, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.018138885498046875, + -9.5367431640625e-7, + 0.010298728942871094, + -9.5367431640625e-7, + -0.0031423568725585938, + -9.5367431640625e-7, + 0.004242897033691406, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.010041236877441406, + 0.0010347366333007812, + 0.006011962890625, + -9.5367431640625e-7, + -9.5367431640625e-7, + 0.00301361083984375, + -0.04584026336669922, + 0.0002079010009765625, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.0002574920654296875, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.45942211151123047, + -0.0008325576782226562, + 0.00041484832763671875, + -9.5367431640625e-7, + -0.023777008056640625, + 0.0000514984130859375, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.00030422210693359375, + 0.0006666183471679688, + -9.5367431640625e-7, + 0.004633903503417969, + -9.5367431640625e-7, + -0.008234977722167969, + -0.07327461242675781 + ], + [ + 0.000003814697265625, + 0.000003814697265625, + -0.00208282470703125, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + -0.0012912750244140625, + -0.01760101318359375, + 0.000003814697265625, + 0.057277679443359375, + 0.013429641723632812, + 0.000003814697265625, + 0.000003814697265625, + -0.0000457763671875, + -0.0027828216552734375, + -0.0055084228515625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + -0.2744255065917969, + 0.000003814697265625, + 0.000003814697265625, + 0.0021514892578125, + 0.000003814697265625, + 0.06994247436523438, + 0.0048542022705078125, + 0.000003814697265625, + 0.000003814697265625, + -0.0567169189453125, + 0.000003814697265625, + 0.012315750122070312, + 0.0066585540771484375, + 0.07937240600585938, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 
0.000003814697265625, + 0.028867721557617188, + 0.000003814697265625, + 0.000003814697265625, + 0.0074901580810546875, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.009624481201171875, + 0.000003814697265625, + -0.009510040283203125, + 0.0032100677490234375, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.10918617248535156, + 0.000003814697265625, + 0.026102066040039062, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000946044921875, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + -0.041675567626953125, + 0.000003814697265625, + 0.000003814697265625, + -0.0066776275634765625, + 0.03926849365234375, + 0.03615379333496094, + 0.027612686157226562, + 0.000003814697265625, + -0.0004673004150390625, + -0.1435985565185547, + -0.00030517578125, + 0.059326171875, + 0.000003814697265625, + 0.000003814697265625, + 0.020435333251953125, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + -0.11923980712890625, + -0.009393692016601562, + 0.000003814697265625, + 0.011783599853515625, + 0.06122589111328125, + 0.000003814697265625, + 0.000003814697265625, + 0.0002918243408203125, + 0.000003814697265625, + 0.000003814697265625, + 0.001491546630859375, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + -0.0050716400146484375, + 0.000003814697265625, + 0.000003814697265625, + 0.025064468383789062, + 0.000003814697265625, + -0.0467529296875, + 0.000003814697265625, + 0.000003814697265625, + 0.0014934539794921875, + 0.00043487548828125, + 0.028188705444335938, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.001995086669921875, + 0.000003814697265625, + 0.000003814697265625, + 0.13014602661132812, + 0.000003814697265625, + 0.000003814697265625, + 
0.000003814697265625, + 0.0005893707275390625, + 0.012182235717773438, + 0.000003814697265625, + 0.000003814697265625, + 0.11103057861328125, + 0.042850494384765625, + 0.030099868774414062, + 0.000003814697265625, + -0.0047321319580078125, + 0.0000133514404296875, + -0.0320587158203125, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.031030654907226562, + -0.002635955810546875 + ], + [ + 0.007018089294433594, + 0, + 0, + 0.0028057098388671875, + 0, + 0, + 0, + -0.010999679565429688, + 0, + -0.1419973373413086, + -0.24188613891601562, + 0, + 0.0003147125244140625, + 0, + 0, + 0, + 0.009432792663574219, + 0, + -0.000125885009765625, + 0, + 0, + 0.00017070770263671875, + 0.011651992797851562, + -0.00225830078125, + 0, + -0.0014581680297851562, + 0.00020122528076171875, + -0.030771255493164062, + -0.03744316101074219, + -0.034499168395996094, + -0.00374603271484375, + 0, + 0.0011348724365234375, + 0, + -0.0302276611328125, + -0.08229637145996094, + -0.00048160552978515625, + 0, + -0.00640869140625, + 0.0001277923583984375, + 0, + -0.0008974075317382812, + 0.00022983551025390625, + -0.2322559356689453, + -0.0050449371337890625, + -0.010677337646484375, + 0, + 0.014942169189453125, + 0, + 0.0008764266967773438, + 0.00417327880859375, + 0, + -0.015301704406738281, + 0, + -0.0008974075317382812, + -0.04426097869873047, + 0.005242347717285156, + 0, + 0, + 0, + -0.009447097778320312, + 0, + 0, + -0.0011806488037109375, + 0, + -0.0045909881591796875, + 0.015285491943359375, + -0.034976959228515625, + -0.013401985168457031, + 0, + 0.1357421875, + -0.09111690521240234, + 0, + 0.00013065338134765625, + 0.0002460479736328125, + 0, + 0.04656982421875, + -0.09346866607666016, + 0, + -0.005030632019042969, + 0.0001125335693359375, + 0, + 0, + 0, + -0.07491683959960938, + 0.006598472595214844, + 0, + -0.014060020446777344, + -0.008306503295898438, + -0.0054874420166015625, + 
-0.0004930496215820312, + 0, + 0, + 0, + -0.008953094482421875, + 0, + 0, + 0, + -0.03713417053222656, + 0, + 0, + 0, + 0, + -0.028200149536132812, + 0, + 0.036255836486816406, + 0, + -0.03178215026855469, + 0, + -0.012192726135253906, + -0.002147674560546875, + -0.0005474090576171875, + -0.0021409988403320312, + -0.030725479125976562, + 0, + 0.0008029937744140625, + 0, + 0, + 0, + -0.29135894775390625, + 0, + 0, + 0, + 0, + 0, + -0.0027914047241210938, + -0.00022125244140625, + -0.8653240203857422, + 0, + -0.05593109130859375, + 0, + -0.04123210906982422, + 0, + 0, + 0.015351295471191406, + 0, + 0, + 0, + 0.018423080444335938, + -0.0000476837158203125, + -0.0023584365844726562, + -0.3282146453857422 + ], + [ + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.0001983642578125, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + -0.010341644287109375, + 0.07198715209960938, + 0.14725303649902344, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.0002918243408203125, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.011704444885253906, + 9.5367431640625e-7, + 9.5367431640625e-7, + -0.3150959014892578, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + -0.039947509765625, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.15607547760009766, + 9.5367431640625e-7, + 0.09917640686035156, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.019521713256835938, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.012205123901367188, + 9.5367431640625e-7, + -0.0005893707275390625, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.07062149047851562, + 0.000492095947265625, + 
0.014776229858398438, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.0557098388671875, + 0.15409469604492188, + 9.5367431640625e-7, + 9.5367431640625e-7, + -0.0007076263427734375, + -0.24256324768066406, + 9.5367431640625e-7, + 0.0858917236328125, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.007343292236328125, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + -0.11646080017089844, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.05528736114501953, + 0.0847921371459961, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.00428009033203125, + -0.0056171417236328125, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.0066967010498046875, + 9.5367431640625e-7, + -0.006005287170410156, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.01735687255859375, + -0.0037336349487304688, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.09533309936523438, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.009324073791503906, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.007989883422851562, + 9.5367431640625e-7, + 0.0064525604248046875, + 9.5367431640625e-7, + -0.06574440002441406, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.5591859817504883 + ], + [ + -0.0009012222290039062, + 0.00001239776611328125, + 0.00001239776611328125, + 
-0.0006313323974609375, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.000461578369140625, + 0.00001239776611328125, + -0.055993080139160156, + -0.24974536895751953, + 0.00001239776611328125, + 0.0011262893676757812, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0025796890258789062, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.030013084411621094, + -0.012925148010253906, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0253448486328125, + 0.00001239776611328125, + 0.0012464523315429688, + 0.021536827087402344, + 0.00001239776611328125, + 0.00001239776611328125, + -0.00009822845458984375, + 0.00001239776611328125, + -0.09924793243408203, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.006188392639160156, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0010576248168945312, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.0008172988891601562, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0020704269409179688, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.09985160827636719, + 0.00001239776611328125, + 0.00001239776611328125, + 0.036945343017578125, + 0.025011062622070312, + 0.00001239776611328125, + 0.004599571228027344, + 0.00001239776611328125, + 0.00001239776611328125, + 0.027939796447753906, + -0.07974910736083984, + 0.00001239776611328125, + 0.00001239776611328125, + -0.00038242340087890625, + 
0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.035175323486328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0047245025634765625, + -0.008166313171386719, + -0.008578300476074219, + 0.0018529891967773438, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0016679763793945312, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0028676986694335938, + -0.04880046844482422, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0053462982177734375, + 0.00001239776611328125, + 0.1658468246459961, + 0.00001239776611328125, + -0.0024824142456054688, + 0.00001239776611328125, + 0.025139808654785156, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.027915000915527344, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.14544200897216797, + 0.00001239776611328125, + 0.020270347595214844, + 0.007473945617675781, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.8424196243286133, + 0.00001239776611328125, + -0.007409095764160156, + -0.00318145751953125, + -0.015982627868652344, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00034046173095703125, + 0.10727787017822266, + 0.00001239776611328125, + 0.00001239776611328125, + -0.5388059616088867 + ], + [ + -0.00000762939453125, + -0.00000762939453125, + 0.0019397735595703125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.022940635681152344, + 0.07428932189941406, + 0.29994869232177734, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.016974449157714844, + -0.00000762939453125, + -0.00000762939453125, + 
-0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.4772310256958008, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.05463600158691406, + -0.00000762939453125, + -0.00000762939453125, + 0.004734992980957031, + -0.12352275848388672, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.15236186981201172, + -0.00000762939453125, + 0.27855396270751953, + -0.00000762939453125, + -0.00000762939453125, + 0.001430511474609375, + -0.00000762939453125, + -0.00000762939453125, + -0.016387939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.0008668899536132812, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.06814861297607422, + -0.00000762939453125, + 0.00351715087890625, + -0.00000762939453125, + 0.0061588287353515625, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.02480602264404297, + 0.31668567657470703, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.4413900375366211, + -0.00000762939453125, + 0.1517791748046875, + -0.00000762939453125, + -0.00000762939453125, + 0.010898590087890625, + -0.00000762939453125, + 0.006583213806152344, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.04946422576904297, + -0.00000762939453125, + -0.00000762939453125, + 0.040429115295410156, + 0.1020956039428711, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.0008649826049804688, + 
-0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.0054836273193359375, + -0.010519981384277344, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.004588127136230469, + -0.006558418273925781, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.07750797271728516, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.03235149383544922, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.02773571014404297, + 0.08978557586669922, + -0.00000762939453125, + 0.008780479431152344, + -0.00000762939453125, + -0.0327301025390625, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.035370826721191406, + 0.19881343841552734 + ], + [ + 0.0041904449462890625, + -0.0000057220458984375, + -0.0000057220458984375, + 0.008458137512207031, + -0.0000057220458984375, + -0.0042858123779296875, + -0.0000057220458984375, + 0.002468109130859375, + -0.0000057220458984375, + -0.03716564178466797, + -0.10456657409667969, + -0.0047702789306640625, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + 0.017671585083007812, + 0.0004062652587890625, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.00655364990234375, + 0.001873016357421875, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.04653644561767578, + -0.01836395263671875, + 0.014448165893554688, + -0.0209197998046875, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.006618499755859375, + 0.02408599853515625, + -0.0000057220458984375, + 0.0012884140014648438, + -0.0000057220458984375, + -0.0000057220458984375, + 
-0.0000057220458984375, + -0.0050144195556640625, + -0.0000057220458984375, + 0.036708831787109375, + -0.0000057220458984375, + -0.0056591033935546875, + -0.0004215240478515625, + 0.0014057159423828125, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.033232688903808594, + -0.0000057220458984375, + -0.0000057220458984375, + -0.008130073547363281, + 0.016930580139160156, + -0.0000057220458984375, + -0.0012025833129882812, + -0.0000057220458984375, + 0.000545501708984375, + -0.0000057220458984375, + -0.0004673004150390625, + -0.0000057220458984375, + 0.0038089752197265625, + -0.008646011352539062, + -0.0000057220458984375, + -0.008909225463867188, + -0.011255264282226562, + -0.0000057220458984375, + 0.0925750732421875, + -0.0064563751220703125, + -0.0000057220458984375, + 0.0011615753173828125, + 0.00002956390380859375, + -0.0000057220458984375, + 0.07063961029052734, + -0.030902862548828125, + -0.0000057220458984375, + -0.0010814666748046875, + 0.00038909912109375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0013380050659179688, + -0.022397994995117188, + 0.027740478515625, + -0.0000057220458984375, + -0.0000057220458984375, + 0.01797771453857422, + 0.009552955627441406, + 0.01857471466064453, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0055103302001953125, + -0.0000057220458984375, + -0.0000057220458984375, + -0.01697063446044922, + -0.0159149169921875, + -0.011240959167480469, + 0.000301361083984375, + 0.020501136779785156, + -0.0000057220458984375, + -0.0006427764892578125, + -0.0000057220458984375, + 0.04800224304199219, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.00001811981201171875, + 0.0005950927734375, + 0.00732421875, + -0.0000057220458984375, + 0.001216888427734375, + 0.00897216796875, + 
-0.1255035400390625, + 0.001003265380859375, + 0.006274223327636719, + 0.0026502609252929688, + -0.00449371337890625, + -0.0023517608642578125, + -0.0000057220458984375, + -0.0000057220458984375, + -0.6521244049072266, + -0.009072303771972656, + 0.013387680053710938, + -0.0000057220458984375, + -0.022745132446289062, + -0.0000057220458984375, + -0.0000057220458984375, + 0.000606536865234375, + -0.0011501312255859375, + -0.0000057220458984375, + -0.0000057220458984375, + 0.023046493530273438, + -0.0000057220458984375, + -0.008263587951660156, + -0.11597061157226562 + ], + [ + -0.0037221908569335938, + 0.00225830078125, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.05941295623779297, + 0.04140281677246094, + 0.24284648895263672, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.11462688446044922, + 0.012240409851074219, + 0.0012884140014648438, + -0.00001049041748046875, + 0.01781749725341797, + 0.005211830139160156, + -0.0016298294067382812, + -0.2994966506958008, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 0.26962947845458984, + -0.00001049041748046875, + 0.050202369689941406, + 0.04053211212158203, + -0.30355358123779297, + -0.00001049041748046875, + -0.0013666152954101562, + -0.00001049041748046875, + 0.06442928314208984, + -0.00001049041748046875, + 0.04406547546386719, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00763702392578125, + -0.00001049041748046875, + -0.03402233123779297, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.005751609802246094, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 
0.13191986083984375, + -0.00001049041748046875, + -0.031653404235839844, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.02394390106201172, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.1073293685913086, + 0.20270729064941406, + 0.02746295928955078, + -0.00001049041748046875, + 0.020377159118652344, + -0.31055259704589844, + -0.043480873107910156, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 0.04507160186767578, + -0.0014734268188476562, + -0.00001049041748046875, + 0.048813819885253906, + -0.00001049041748046875, + -0.00001049041748046875, + -0.1407604217529297, + -0.00001049041748046875, + -0.00001049041748046875, + 0.013548851013183594, + 0.016210556030273438, + -0.00001049041748046875, + -0.011261940002441406, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00029277801513671875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 0.008993148803710938, + -0.020813941955566406, + -0.00001049041748046875, + -0.00001049041748046875, + -0.008435249328613281, + -0.021961212158203125, + -0.04410362243652344, + 0.1307668685913086, + 0.005297660827636719, + -0.00001049041748046875, + 0.006031990051269531, + 0.016150474548339844, + -0.00001049041748046875, + -0.00001049041748046875, + 0.01802349090576172, + 0.0018205642700195312, + 0.0016574859619140625, + 0.0005712509155273438, + -0.00001049041748046875, + -0.00001049041748046875, + 0.02598094940185547, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 0.02737903594970703, + 0.039580345153808594, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 0.09876728057861328, + 0.035803794860839844, + -0.00001049041748046875, + -0.027251243591308594, + -0.00001049041748046875, 
+ -0.07061004638671875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 0.08719158172607422, + -0.37606334686279297 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" 
+ ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + 
], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + 
"scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], 
+ [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + 
"gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Change in logit diff when ablating L5 SAE features for all prompts at pos 10" + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Feature Idx" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Prompt Idx" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def ablate_sae_feature(sae_acts, hook, pos, feature_id):\n", + " if pos is None:\n", + " sae_acts[:, :, feature_id] = 0.\n", + " else:\n", + " sae_acts[:, pos, feature_id] = 0.\n", + " return sae_acts\n", + "\n", + "layer = 5\n", + "hooked_encoder = model.acts_to_saes[utils.get_act_name('z', layer)]\n", + "all_live_features = torch.arange(hooked_encoder.cfg.d_sae)[live_feature_union.cpu()]\n", + "\n", + "causal_effects = torch.zeros((len(prompts), all_live_features.shape[0]))\n", + "fid_to_idx = {fid.item(): idx for idx, fid in enumerate(all_live_features)}\n", + "\n", + "\n", + "abl_layer, abl_pos = 5, 10\n", + "for feature_id in tqdm.tqdm(all_live_features):\n", + " feature_id = feature_id.item()\n", + " abl_feature_logits = model.run_with_hooks(\n", + " tokens,\n", + " return_type=\"logits\",\n", + " fwd_hooks=[(utils.get_act_name('z', abl_layer) + \".hook_sae_acts_post\", partial(ablate_sae_feature, pos=abl_pos, feature_id=feature_id))]\n", + " ) # [batch, seq, vocab]\n", + " \n", + " abl_feature_logit_diff = logits_to_ave_logit_diff(abl_feature_logits, answer_tokens, per_prompt=True) # [batch]\n", + " causal_effects[:, fid_to_idx[feature_id]] = abl_feature_logit_diff - original_per_prompt_logit_diff\n", + "\n", + "def able_sae_error(sae_error, hook, pos):\n", + " if pos is None:\n", + " sae_error = 0.\n", + " else:\n", + " sae_error[:, pos, ...] 
= 0.\n", + " return sae_error\n", + "\n", + "\n", + "abl_error_logits = model.run_with_hooks(\n", + " tokens,\n", + " return_type=\"logits\",\n", + " fwd_hooks=[(utils.get_act_name('z', abl_layer) + \".hook_sae_error\", partial(able_sae_error, pos=abl_pos))]\n", + ") # [batch, seq, vocab]\n", + "\n", + "abl_error_logit_diff = logits_to_ave_logit_diff(abl_error_logits, answer_tokens, per_prompt=True) # [batch]\n", + "error_abl_effect = abl_error_logit_diff - original_per_prompt_logit_diff\n", + "\n", + "\n", + "causal_effects_with_error = torch.cat([causal_effects, error_abl_effect.unsqueeze(-1).cpu()], dim=-1)\n", + "imshow(causal_effects_with_error, title=f\"Change in logit diff when ablating L{abl_layer} SAE features for all prompts at pos {abl_pos}\", xaxis=\"Feature Idx\", yaxis=\"Prompt Idx\", x=list(map(str, all_live_features.tolist()))+[\"error\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that on some prompts, ablating the error term (right most column) does have a non trivial effect on the logit diff, although I don't see a clear pattern. It seems useful to include this term when doing causal interventions to get a better sense of how much the SAE features are actually explaining. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Attribution patching " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "Both [Anthropic](https://transformer-circuits.pub/2024/march-update/index.html#feature-heads) and [Marks et al](https://arxiv.org/abs/2403.19647v2). also demonstrated the use of gradient based attribution techniques as a substitute for activation patching on SAE features. The key idea is that patching / ablations (as we did above) can be slow, as it requires a new forward pass for each patch. This seems especially problematic when dealing with SAEs with tens of thousands of features per activation. 
They find that gradient based attribution techniques like [attribution patching](https://www.neelnanda.io/mechanistic-interpretability/attribution-patching) are good approximations, allowing for more efficient and scalable circuit analysis with SAEs.\n", + "\n", + "With `HookedSAETransformer`, added SAEs are automatically spliced into the computational graph, allowing us to implement this easily. Let's implement attribution patching for every L5 SAE feature to find causally relevant SAE features with just one forward and one backward pass." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.set_grad_enabled(True)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-7.6294e-06, device='cuda:0', grad_fn=)\n", + "Clean Value: -7.62939453125e-06\n", + "Clean Activations Cached: 1\n", + "Clean Gradients Cached: 1\n" + ] + } + ], + "source": [ + "from transformer_lens import ActivationCache\n", + "filter_sae_acts = lambda name: (\"hook_sae_acts_post\" in name)\n", + "def get_cache_fwd_and_bwd(model, tokens, metric):\n", + " model.reset_hooks()\n", + " cache = {}\n", + " def forward_cache_hook(act, hook):\n", + " cache[hook.name] = act.detach()\n", + " model.add_hook(filter_sae_acts, forward_cache_hook, \"fwd\")\n", + "\n", + " grad_cache = {}\n", + " def backward_cache_hook(act, hook):\n", + " grad_cache[hook.name] = act.detach()\n", + " model.add_hook(filter_sae_acts, backward_cache_hook, \"bwd\")\n", + "\n", + " value = metric(model(tokens))\n", + " print(value)\n", + " value.backward()\n", + " model.reset_hooks()\n", + " return value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)\n", + "\n", + "\n", + "BASELINE = 
original_per_prompt_logit_diff\n", + "def ioi_metric(logits, answer_tokens=answer_tokens):\n", + " return (logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=True) - BASELINE).sum()\n", + "\n", + "clean_tokens = tokens.clone()\n", + "clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(model, clean_tokens, ioi_metric)\n", + "print(\"Clean Value:\", clean_value)\n", + "print(\"Clean Activations Cached:\", len(clean_cache))\n", + "print(\"Clean Gradients Cached:\", len(clean_grad_cache))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Feature Idx: %{x}
Prompt Idx: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "x": [ + "46", + "345", + "702", + "1372", + "1755", + "1965", + "2457", + "2496", + "2646", + "2999", + "3047", + "4569", + "5132", + "5203", + "5508", + "5940", + "6144", + "6371", + "6515", + "6558", + "6812", + "7092", + "7515", + "7907", + "8063", + "8623", + "8737", + "8768", + "9096", + "9102", + "9186", + "9463", + "9746", + "9913", + "10581", + "10894", + "12109", + "12485", + "12764", + "12866", + "13063", + "13624", + "13707", + "13777", + "14844", + "15050", + "15170", + "15696", + "16178", + "16892", + "17156", + "17259", + "17497", + "17854", + "18043", + "18210", + "18318", + "18385", + "18440", + "18920", + "19183", + "19263", + "19442", + "19524", + "19573", + "20838", + "21151", + "21657", + "22108", + "23578", + "24091", + "24217", + "25792", + "26373", + "26410", + "27535", + "27787", + "27811", + "27960", + "28061", + "28241", + "28242", + "28254", + "28349", + "28977", + "29027", + "29482", + "29603", + "29700", + "29822", + "32177", + "32920", + "33320", + "33730", + "33966", + "34177", + "34334", + "34947", + "35403", + "35425", + "35579", + "35665", + "35815", + "36109", + "36172", + "36451", + "36767", + "36917", + "38570", + "39962", + "40409", + "40418", + "40661", + "41162", + "41185", + "41552", + "42024", + "42161", + "42437", + "42577", + "42882", + "42931", + "43035", + "43414", + "43643", + "43662", + "44203", + "44256", + "44452", + "44652", + "45179", + "45814", + "45984", + "46880", + "47117", + "47170", + "47231", + "47313", + "47680", + "48063", + "48703" + ], + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.001567811705172062, + 0, + 0, + 0.001697835512459278, + 0.00011560246639419347, + 0, + 0, + -0.0002851475146599114, + 0, + -0.030827227979898453, + -0.06409652531147003, + 0, + 0.00015167289529927075, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.013627146370708942, + -0.004393726587295532, + 0.0015328703448176384, + -0.0038613511715084314, + 
0, + 0, + 0, + 0, + -0.02049136720597744, + 0, + 0, + -0.007114107254892588, + 0.0003477374848444015, + -0.001384311355650425, + 0.003183899214491248, + 0.0004558839718811214, + -0.059277813881635666, + -0.0035793157294392586, + -0.00589390005916357, + 0, + -0.0001910730206873268, + 0, + 0.0006608504336327314, + -0.0004212319909129292, + 0, + -0.003545185085386038, + -0.00327106611803174, + 0, + 0, + 0.0040074847638607025, + 0, + 0, + 0, + -0.0026069351006299257, + 0, + 0, + 0, + -0.00008433026232523844, + -0.00018646706303115934, + 0, + -0.00439279293641448, + -0.013254894874989986, + 0, + 0.050094299018383026, + -0.021308520808815956, + 0, + -0.0006410681526176631, + 0, + 0, + 0.02329532988369465, + -0.05166983604431152, + -0.002982117934152484, + -0.000014124364497547504, + -0.0020334068685770035, + 0, + 0, + 0, + -0.02020590752363205, + 0.00998645182698965, + 0, + -0.004585121292620897, + 0.005916096270084381, + 0.0018219061894342303, + 0.005700498353689909, + 0.0008085825829766691, + 0, + 0, + -0.0032405084930360317, + 0, + 0, + 0, + -0.014961971901357174, + 0, + 0, + 0, + 0, + -0.016915086656808853, + 0, + 0.016825370490550995, + 0, + -0.00311169121414423, + 0, + 0.005266942549496889, + 0, + 0, + 0, + -0.009660078212618828, + 0.0010975055629387498, + 0.006078756880015135, + 0, + 0, + 0.003166533075273037, + -0.044512320309877396, + 0.0002630578528624028, + 0, + 0, + -0.00025422731414437294, + 0, + 0, + 0, + -0.3718416392803192, + -0.0008081833366304636, + 0.00043700754758901894, + 0, + -0.023154418915510178, + 0.00004691413778346032, + 0, + 0, + -0.0002914638607762754, + 0.0006733346963301301, + 0, + 0.008972969837486744, + 0, + -0.008168808184564114 + ], + [ + 0, + 0, + -0.0006953283445909619, + 0, + 0, + 0, + -0.001286927843466401, + -0.017273705452680588, + 0, + 0.05898163467645645, + 0.013462062925100327, + 0, + 0, + -0.00003325308352941647, + -0.0027551515959203243, + -0.004652985371649265, + 0, + 0, + 0, + 0, + 0, + 0, + -0.21421866118907928, + 0, + 0, 
+ 0.002191215055063367, + 0, + 0.07645706832408905, + 0.0052618952468037605, + 0, + 0, + -0.020269982516765594, + 0, + 0.013446477241814137, + 0.0068704248405992985, + 0.08710267394781113, + 0, + 0, + 0, + 0, + 0.028982989490032196, + 0, + 0, + 0.014961526729166508, + 0, + 0, + 0, + 0, + 0.011233230121433735, + 0, + -0.009112805128097534, + 0.003226917004212737, + 0, + 0, + 0, + 0.112985759973526, + 0, + 0.028253009542822838, + 0, + 0, + 0, + 0.0009787877788767219, + 0, + 0, + 0, + 0, + -0.03986968472599983, + 0, + 0, + -0.006135094445198774, + 0.04977395758032799, + 0.0397123359143734, + 0.027974072843790054, + 0, + -0.00044811973930336535, + -0.10083132237195969, + 0.000008234118467953522, + 0.06165996566414833, + 0, + 0, + 0.021058127284049988, + 0, + 0, + 0, + 0, + 0, + -0.08074336498975754, + -0.009298793971538544, + 0, + 0.012482613325119019, + 0.06513619422912598, + 0, + 0, + 0.00029019018984399736, + 0, + 0, + 0.0014882637187838554, + 0, + 0, + 0, + -0.004803473129868507, + 0, + 0, + 0.025678949430584908, + 0, + -0.04240157827734947, + 0, + 0, + 0.0015190609265118837, + 0.0006482255994342268, + 0.03654245659708977, + 0, + 0, + 0, + 0, + 0, + 0.0020186977926641703, + 0, + 0, + 0.17831696569919586, + 0, + 0, + 0, + 0.0005887048901058733, + 0.012331255711615086, + 0, + 0, + 0.11619613319635391, + 0.04687207192182541, + 0.03033648431301117, + 0, + -0.004195880610495806, + 0.00006391256465576589, + -0.03162289038300514, + 0, + 0, + 0, + 0, + 0, + 0, + 0.03672636300325394 + ], + [ + 0.00788492988795042, + 0, + 0, + 0.003685369621962309, + 0, + 0, + 0, + -0.010384900495409966, + 0, + -0.1327948272228241, + -0.22788244485855103, + 0, + 0.0003893508983310312, + 0, + 0, + 0, + 0.009530982933938503, + 0, + -0.0001286355109186843, + 0, + 0, + 0.0001596187794348225, + 0.011789986863732338, + -0.0022452236153185368, + 0, + -0.0014552043285220861, + 0.0002036036894423887, + -0.03003234602510929, + -0.036742936819791794, + -0.028862446546554565, + -0.003727517556399107, + 
0, + 0.0011460097739472985, + 0, + -0.027142589911818504, + -0.054151974618434906, + -0.0004727205669041723, + 0, + -0.006094601005315781, + 0.00013960858632344753, + 0, + -0.0003665595140773803, + 0.00028091753483749926, + -0.17846877872943878, + -0.004990901332348585, + -0.010615025646984577, + 0, + 0.015916047617793083, + 0, + 0.0008773574372753501, + 0.004459311719983816, + 0, + -0.015235064551234245, + 0, + -0.0008741968194954097, + -0.04074608162045479, + 0.007227533031255007, + 0, + 0, + 0, + -0.007763775996863842, + 0, + 0, + -0.0011336231837049127, + 0, + -0.004542750306427479, + 0.016146792098879814, + -0.032868705689907074, + -0.013282506726682186, + 0, + 0.1884474903345108, + -0.07819699496030807, + 0, + 0.00013099861098453403, + 0.00024322106037288904, + 0, + 0.04764547944068909, + -0.09056885540485382, + 0, + -0.005007788073271513, + 0.000487087934743613, + 0, + 0, + 0, + -0.07196655869483948, + 0.007451012264937162, + 0, + -0.013892672955989838, + -0.005596193019300699, + -0.005349555052816868, + -0.00015437132969964296, + 0, + 0, + 0, + -0.00894666463136673, + 0, + 0, + 0, + -0.036862581968307495, + 0, + 0, + 0, + 0, + -0.026162482798099518, + 0, + 0.046491872519254684, + 0, + -0.030160455033183098, + 0, + -0.009029642678797245, + -0.0021479984279721975, + -0.0005375721957534552, + -0.002135993679985404, + -0.027962258085608482, + 0, + 0.0008057129452936351, + 0, + 0, + 0, + -0.26795026659965515, + 0, + 0, + 0, + 0, + 0, + -0.0027670287527143955, + -0.0002252299600513652, + -0.7548060417175293, + 0, + -0.05009680241346359, + 0, + -0.03914204612374306, + 0, + 0, + 0.016279445961117744, + 0, + 0, + 0, + 0.025662390515208244, + -0.000049459828005637974, + -0.0023572721984237432 + ], + [ + 0, + 0, + 0.0009027881897054613, + 0, + 0, + 0, + 0, + -0.01007400918751955, + 0.07334298640489578, + 0.15174342691898346, + 0, + 0, + 0, + 0, + 0, + 0.0007311829249374568, + 0, + 0, + 0, + 0.011839455924928188, + 0, + 0, + -0.2282165139913559, + 0, + 0, + 0, + 0, + 
0, + 0, + 0, + 0, + -0.017542533576488495, + 0, + 0, + 0, + 0.1636323779821396, + 0, + 0.10289037227630615, + 0, + 0, + 0, + 0, + 0, + 0.024433566257357597, + 0, + 0, + 0, + 0, + 0.013018166646361351, + 0, + -0.0005916667287237942, + 0, + 0, + 0, + 0, + 0.07111621648073196, + 0.0004984873230569065, + 0.015917964279651642, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.06262800097465515, + 0.17253385484218597, + 0, + 0, + -0.0007970984443090856, + -0.1451263427734375, + 0, + 0.08718064427375793, + 0, + 0, + 0.007446629460901022, + 0, + 0, + 0, + 0, + 0, + -0.09546831995248795, + 0, + 0, + 0.06110787391662598, + 0.08931172639131546, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.005256101489067078, + -0.00553735950961709, + 0, + 0, + 0.006732907146215439, + 0, + -0.005547903478145599, + 0, + 0, + 0, + 0.01766844280064106, + -0.0034187675919383764, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.1122211441397667, + 0, + 0, + 0, + 0.009442206472158432, + 0, + 0, + 0, + 0, + 0, + 0.00800288561731577, + 0, + 0.006613056641072035, + 0, + -0.06462590396404266, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + -0.0009047402418218553, + 0, + 0, + -0.0005877931835129857, + 0, + 0, + 0, + -0.0004729636711999774, + 0, + -0.05036322772502899, + -0.24687804281711578, + 0, + 0.001115482416935265, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0024291854351758957, + 0, + 0, + 0, + -0.029154174029827118, + -0.011211197823286057, + 0, + 0, + 0, + 0, + 0, + -0.0075091151520609856, + 0, + 0.0037634933833032846, + 0.022711526602506638, + 0, + 0, + -0.00011145337339257821, + 0, + -0.08350298553705215, + 0, + 0, + 0, + 0, + 0, + 0.0063380529172718525, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0010615212377160788, + 0, + 0, + 0, + 0.001314864493906498, + 0, + 0, + -0.0020079570822417736, + 0, + 0, + 0, + -0.095857173204422, + 0, + 0, + 0.04977884888648987, + 0.04924672096967697, + 0, + 0.00675918348133564, + 0, + 0, + 0.02823697216808796, + -0.07869893312454224, + 0, + 0, + 
-0.00039145027403719723, + 0, + 0, + 0, + -0.03502006456255913, + 0, + 0, + -0.004709419794380665, + -0.007543480955064297, + -0.007213911972939968, + 0.0026987697929143906, + 0, + 0, + 0, + -0.0016787010245025158, + 0, + 0, + -0.002866228111088276, + -0.04759479686617851, + 0, + 0, + 0, + 0, + -0.005348640959709883, + 0, + 0.17661413550376892, + 0, + -0.0024743194226175547, + 0, + 0.0269751138985157, + 0, + 0, + 0, + -0.025461290031671524, + 0, + 0, + 0, + 0, + 0, + -0.14607883989810944, + 0, + 0.020490022376179695, + 0.007573024369776249, + 0, + 0, + 0, + 0, + -0.8939738869667053, + 0, + -0.006900197826325893, + -0.0031849159859120846, + -0.015817783772945404, + 0, + 0, + 0, + 0, + 0, + 0.00032859406201168895, + 0.11629504710435867, + 0, + 0 + ], + [ + 0, + 0, + 0.0020032059401273727, + 0, + 0, + 0, + 0, + -0.02256190776824951, + 0.07616151124238968, + 0.3106333911418915, + 0, + 0, + 0, + 0, + 0, + -0.014044971205294132, + 0, + 0, + 0, + 0, + 0, + 0, + -0.3483165502548218, + 0, + 0, + 0, + 0, + 0.05930393189191818, + 0, + 0, + 0.004992437083274126, + -0.08404884487390518, + 0, + 0, + 0, + 0.16281214356422424, + 0, + 0.28443410992622375, + 0, + 0, + 0.0014393558958545327, + 0, + 0, + -0.009063852950930595, + 0, + 0, + 0, + 0, + 0.001169737195596099, + 0, + 0, + 0, + 0, + 0, + 0, + 0.06898342072963715, + 0, + 0.007991905324161053, + 0, + 0.006260615773499012, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.037955716252326965, + 0.3505173921585083, + 0, + 0, + 0, + -0.338177889585495, + 0, + 0.158599853515625, + 0, + 0, + 0.01131439208984375, + 0, + 0.006751265376806259, + 0, + 0, + 0, + -0.04573351889848709, + 0, + 0, + 0.04386100172996521, + 0.11277603358030319, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0003205372195225209, + 0, + 0, + 0, + -0.005409737583249807, + -0.009204162284731865, + 0, + 0, + 0, + 0.004804544150829315, + -0.005810749251395464, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.09645535796880722, + 0, + 0, + 0, + 0.032931435853242874, 
+ 0, + 0, + 0, + 0, + 0.028524864464998245, + 0.09402520954608917, + 0, + 0.008998546749353409, + 0, + -0.03251685947179794, + 0, + 0, + 0, + 0, + 0, + 0, + 0.037343256175518036 + ], + [ + 0.0045840502716600895, + 0, + 0, + 0.009021798148751259, + 0, + -0.004217533860355616, + 0, + 0.0025705555453896523, + 0, + -0.035309672355651855, + -0.09942735731601715, + -0.004700342193245888, + 0, + 0, + 0, + 0.018288278952240944, + 0.0004021169152110815, + 0, + 0, + 0, + 0, + -0.005593586713075638, + 0.0018821493722498417, + 0, + 0, + 0, + 0, + -0.04561242088675499, + -0.01815006509423256, + 0.016583485528826714, + -0.020843051373958588, + 0, + 0, + 0, + -0.006372869946062565, + 0.04272369295358658, + 0, + 0.0013309348141774535, + 0, + 0, + 0, + -0.0031638317741453648, + 0, + 0.08714215457439423, + 0, + -0.005442100111395121, + -0.00039313771412707865, + 0.0014464370906352997, + 0, + 0, + 0, + 0, + -0.03132649511098862, + 0, + 0, + -0.007972904480993748, + 0.01753396727144718, + 0, + -0.0011563192820176482, + 0, + 0.0017362519865855575, + 0, + -0.0004587600124068558, + 0, + 0.0038881096988916397, + -0.008516360074281693, + 0, + -0.008183307014405727, + -0.010095844976603985, + 0, + 0.10722006857395172, + -0.002898464212194085, + 0, + 0.0012827662285417318, + 0.00004252225699019618, + 0, + 0.07567721605300903, + -0.030121177434921265, + 0, + -0.0010666534071788192, + 0.0006539365276694298, + 0, + 0, + -0.0011567147448658943, + -0.021622339263558388, + 0.028687214478850365, + 0, + 0, + 0.018764594569802284, + 0.010613140650093555, + 0.019510075449943542, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.005288010463118553, + 0, + 0, + -0.016743114218115807, + -0.015873711556196213, + -0.009877816773951054, + 0.0003150522243231535, + 0.023689158260822296, + 0, + -0.00033418016391806304, + 0, + 0.04904749244451523, + 0, + 0, + 0, + 0.0006500506424345076, + 0.000622213410679251, + 0.00738720316439867, + 0, + 0.0012243357487022877, + 0.009066173806786537, + -0.12073952704668045, + 
0.0010678119724616408, + 0.006296947598457336, + 0.002682592486962676, + -0.00444818427786231, + -0.0023324599023908377, + 0, + 0, + -0.5609893798828125, + -0.008780602365732193, + 0.015986066311597824, + 0, + -0.02213476411998272, + 0, + 0, + 0.0006705078994855285, + -0.0011221399763599038, + 0, + 0, + 0.025299811735749245, + 0, + -0.008218510076403618 + ], + [ + -0.0034782839938998222, + 0.0022423912305384874, + 0, + 0, + 0, + 0, + 0, + -0.05859537422657013, + 0.0421387143433094, + 0.26256099343299866, + 0, + 0, + 0, + 0, + 0, + -0.10330676287412643, + 0.012355834245681763, + 0.0013472040882334113, + 0, + 0.019914263859391212, + 0.005261276848614216, + 0.001149827498011291, + -0.03320133313536644, + 0, + 0, + 0, + 0, + 0.32198745012283325, + 0, + 0.05401667580008507, + 0.04610951617360115, + -0.2326284795999527, + 0, + 0.0000856258993735537, + 0, + 0.074106365442276, + 0, + 0.044469863176345825, + 0, + 0, + 0, + -0.006453251000493765, + 0, + -0.018431225791573524, + 0, + 0, + 0, + 0, + -0.005704954732209444, + 0, + 0, + 0, + 0, + 0, + 0, + 0.13457728922367096, + 0, + -0.029186677187681198, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.022995056584477425, + 0, + 0, + 0, + -0.09004921466112137, + 0.24257110059261322, + 0.02852930873632431, + 0, + 0.021270141005516052, + -0.13564155995845795, + -0.03098711557686329, + 0, + 0, + 0, + 0.0486220121383667, + -0.001395023544318974, + 0, + 0.04929636791348457, + 0, + 0, + -0.13068373501300812, + 0, + 0, + 0.016955919563770294, + 0.03848254308104515, + 0, + -0.011160435155034065, + 0, + 0, + -0.0002991429646499455, + 0, + 0, + 0, + 0.01138608530163765, + -0.020150866359472275, + 0, + 0, + -0.007353566121309996, + -0.021389631554484367, + -0.042083244770765305, + 0.13586723804473877, + 0.005315479822456837, + 0, + 0.008157049305737019, + 0.022239860147237778, + 0, + 0, + 0.01896926946938038, + 0.0018052944215014577, + 0.0016496418975293636, + 0.0005593635141849518, + 0, + 0, + 0.07655386626720428, + 0, + 0, + 0, + 
0.02781328558921814, + 0.04012482985854149, + 0, + 0, + 0, + 0.10631410032510757, + 0.03608629107475281, + 0, + -0.02651066705584526, + 0, + -0.0690990686416626, + 0, + 0, + 0, + 0, + 0, + 0, + 0.1022648885846138 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 
0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": 
"histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + 
"marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, 
+ "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + 
"linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "attribution patching" + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Feature Idx" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Prompt Idx" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def attr_patch_sae_acts(\n", + " clean_cache: ActivationCache, \n", + " clean_grad_cache: ActivationCache,\n", + " site: str, layer: int\n", + " ):\n", + " clean_sae_acts_post = clean_cache[utils.get_act_name(site, layer) + \".hook_sae_acts_post\"] \n", + " clean_grad_sae_acts_post = clean_grad_cache[utils.get_act_name(site, layer) + \".hook_sae_acts_post\"] \n", + " sae_act_attr = clean_grad_sae_acts_post * (0 - clean_sae_acts_post)\n", + " return sae_act_attr\n", + "\n", + "site = \"z\"\n", + "layer = 5\n", + "sae_act_attr = attr_patch_sae_acts(clean_cache, clean_grad_cache, site, layer)\n", + "\n", + "imshow(\n", + " sae_act_attr[:, s2_pos, all_live_features],\n", + " title=\"attribution patching\",\n", + " xaxis=\"Feature Idx\", yaxis=\"Prompt Idx\", x=list(map(str, all_live_features.tolist())))" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hovertemplate": "Activation Patch=%{x}
Attribution Patch=%{y}", + "legendgroup": "", + "marker": { + "color": "#636efa", + "symbol": "circle" + }, + "mode": "markers", + "name": "", + "showlegend": false, + "type": "scattergl", + "x": [ + 0.0012617111206054688, + -9.5367431640625e-7, + -9.5367431640625e-7, + 0.0016908645629882812, + -0.0002231597900390625, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.00029659271240234375, + -9.5367431640625e-7, + -0.03279590606689453, + -0.07254886627197266, + -9.5367431640625e-7, + 0.00013065338134765625, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.014922142028808594, + -0.0044403076171875, + 0.0007047653198242188, + -0.00428009033203125, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.039069175720214844, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.007334709167480469, + 0.00033092498779296875, + -0.0017004013061523438, + 0.0026845932006835938, + 0.00043010711669921875, + -0.11128997802734375, + -0.0038976669311523438, + -0.006033897399902344, + -9.5367431640625e-7, + -0.00027751922607421875, + -9.5367431640625e-7, + 0.0006570816040039062, + -0.0004291534423828125, + -9.5367431640625e-7, + -0.0035734176635742188, + -0.0033063888549804688, + -9.5367431640625e-7, + -9.5367431640625e-7, + 0.0033960342407226562, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.0030546188354492188, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.0000972747802734375, + -0.0001811981201171875, + -9.5367431640625e-7, + -0.004569053649902344, + -0.013583183288574219, + -9.5367431640625e-7, + 0.02047252655029297, + -0.02572154998779297, + -9.5367431640625e-7, + -0.0006608963012695312, + -9.5367431640625e-7, 
+ -9.5367431640625e-7, + 0.02255725860595703, + -0.05519580841064453, + -0.0033473968505859375, + -0.0000057220458984375, + -0.0026073455810546875, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.02097320556640625, + 0.008440971374511719, + -9.5367431640625e-7, + -0.004597663879394531, + 0.00159454345703125, + 0.0001544952392578125, + 0.005199432373046875, + 0.0007762908935546875, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.0032625198364257812, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.015192985534667969, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.018138885498046875, + -9.5367431640625e-7, + 0.010298728942871094, + -9.5367431640625e-7, + -0.0031423568725585938, + -9.5367431640625e-7, + 0.004242897033691406, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.010041236877441406, + 0.0010347366333007812, + 0.006011962890625, + -9.5367431640625e-7, + -9.5367431640625e-7, + 0.00301361083984375, + -0.04584026336669922, + 0.0002079010009765625, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.0002574920654296875, + -9.5367431640625e-7, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.45942211151123047, + -0.0008325576782226562, + 0.00041484832763671875, + -9.5367431640625e-7, + -0.023777008056640625, + 0.0000514984130859375, + -9.5367431640625e-7, + -9.5367431640625e-7, + -0.00030422210693359375, + 0.0006666183471679688, + -9.5367431640625e-7, + 0.004633903503417969, + -9.5367431640625e-7, + -0.008234977722167969, + 0.000003814697265625, + 0.000003814697265625, + -0.00208282470703125, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + -0.0012912750244140625, + -0.01760101318359375, + 0.000003814697265625, + 0.057277679443359375, + 0.013429641723632812, + 0.000003814697265625, + 0.000003814697265625, + -0.0000457763671875, + -0.0027828216552734375, + -0.0055084228515625, + 0.000003814697265625, + 
0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + -0.2744255065917969, + 0.000003814697265625, + 0.000003814697265625, + 0.0021514892578125, + 0.000003814697265625, + 0.06994247436523438, + 0.0048542022705078125, + 0.000003814697265625, + 0.000003814697265625, + -0.0567169189453125, + 0.000003814697265625, + 0.012315750122070312, + 0.0066585540771484375, + 0.07937240600585938, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.028867721557617188, + 0.000003814697265625, + 0.000003814697265625, + 0.0074901580810546875, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.009624481201171875, + 0.000003814697265625, + -0.009510040283203125, + 0.0032100677490234375, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.10918617248535156, + 0.000003814697265625, + 0.026102066040039062, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000946044921875, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + -0.041675567626953125, + 0.000003814697265625, + 0.000003814697265625, + -0.0066776275634765625, + 0.03926849365234375, + 0.03615379333496094, + 0.027612686157226562, + 0.000003814697265625, + -0.0004673004150390625, + -0.1435985565185547, + -0.00030517578125, + 0.059326171875, + 0.000003814697265625, + 0.000003814697265625, + 0.020435333251953125, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + -0.11923980712890625, + -0.009393692016601562, + 0.000003814697265625, + 0.011783599853515625, + 0.06122589111328125, + 0.000003814697265625, + 0.000003814697265625, + 0.0002918243408203125, + 0.000003814697265625, + 0.000003814697265625, + 0.001491546630859375, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 
-0.0050716400146484375, + 0.000003814697265625, + 0.000003814697265625, + 0.025064468383789062, + 0.000003814697265625, + -0.0467529296875, + 0.000003814697265625, + 0.000003814697265625, + 0.0014934539794921875, + 0.00043487548828125, + 0.028188705444335938, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.001995086669921875, + 0.000003814697265625, + 0.000003814697265625, + 0.13014602661132812, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.0005893707275390625, + 0.012182235717773438, + 0.000003814697265625, + 0.000003814697265625, + 0.11103057861328125, + 0.042850494384765625, + 0.030099868774414062, + 0.000003814697265625, + -0.0047321319580078125, + 0.0000133514404296875, + -0.0320587158203125, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.000003814697265625, + 0.031030654907226562, + 0.007018089294433594, + 0, + 0, + 0.0028057098388671875, + 0, + 0, + 0, + -0.010999679565429688, + 0, + -0.1419973373413086, + -0.24188613891601562, + 0, + 0.0003147125244140625, + 0, + 0, + 0, + 0.009432792663574219, + 0, + -0.000125885009765625, + 0, + 0, + 0.00017070770263671875, + 0.011651992797851562, + -0.00225830078125, + 0, + -0.0014581680297851562, + 0.00020122528076171875, + -0.030771255493164062, + -0.03744316101074219, + -0.034499168395996094, + -0.00374603271484375, + 0, + 0.0011348724365234375, + 0, + -0.0302276611328125, + -0.08229637145996094, + -0.00048160552978515625, + 0, + -0.00640869140625, + 0.0001277923583984375, + 0, + -0.0008974075317382812, + 0.00022983551025390625, + -0.2322559356689453, + -0.0050449371337890625, + -0.010677337646484375, + 0, + 0.014942169189453125, + 0, + 0.0008764266967773438, + 0.00417327880859375, + 0, + -0.015301704406738281, + 0, + -0.0008974075317382812, + -0.04426097869873047, + 0.005242347717285156, + 0, + 0, + 0, + -0.009447097778320312, + 0, + 0, 
+ -0.0011806488037109375, + 0, + -0.0045909881591796875, + 0.015285491943359375, + -0.034976959228515625, + -0.013401985168457031, + 0, + 0.1357421875, + -0.09111690521240234, + 0, + 0.00013065338134765625, + 0.0002460479736328125, + 0, + 0.04656982421875, + -0.09346866607666016, + 0, + -0.005030632019042969, + 0.0001125335693359375, + 0, + 0, + 0, + -0.07491683959960938, + 0.006598472595214844, + 0, + -0.014060020446777344, + -0.008306503295898438, + -0.0054874420166015625, + -0.0004930496215820312, + 0, + 0, + 0, + -0.008953094482421875, + 0, + 0, + 0, + -0.03713417053222656, + 0, + 0, + 0, + 0, + -0.028200149536132812, + 0, + 0.036255836486816406, + 0, + -0.03178215026855469, + 0, + -0.012192726135253906, + -0.002147674560546875, + -0.0005474090576171875, + -0.0021409988403320312, + -0.030725479125976562, + 0, + 0.0008029937744140625, + 0, + 0, + 0, + -0.29135894775390625, + 0, + 0, + 0, + 0, + 0, + -0.0027914047241210938, + -0.00022125244140625, + -0.8653240203857422, + 0, + -0.05593109130859375, + 0, + -0.04123210906982422, + 0, + 0, + 0.015351295471191406, + 0, + 0, + 0, + 0.018423080444335938, + -0.0000476837158203125, + -0.0023584365844726562, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.0001983642578125, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + -0.010341644287109375, + 0.07198715209960938, + 0.14725303649902344, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.0002918243408203125, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.011704444885253906, + 9.5367431640625e-7, + 9.5367431640625e-7, + -0.3150959014892578, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + -0.039947509765625, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.15607547760009766, + 9.5367431640625e-7, + 
0.09917640686035156, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.019521713256835938, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.012205123901367188, + 9.5367431640625e-7, + -0.0005893707275390625, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.07062149047851562, + 0.000492095947265625, + 0.014776229858398438, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.0557098388671875, + 0.15409469604492188, + 9.5367431640625e-7, + 9.5367431640625e-7, + -0.0007076263427734375, + -0.24256324768066406, + 9.5367431640625e-7, + 0.0858917236328125, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.007343292236328125, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + -0.11646080017089844, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.05528736114501953, + 0.0847921371459961, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.00428009033203125, + -0.0056171417236328125, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.0066967010498046875, + 9.5367431640625e-7, + -0.006005287170410156, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.01735687255859375, + -0.0037336349487304688, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.09533309936523438, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.009324073791503906, + 9.5367431640625e-7, + 
9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 0.007989883422851562, + 9.5367431640625e-7, + 0.0064525604248046875, + 9.5367431640625e-7, + -0.06574440002441406, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + 9.5367431640625e-7, + -0.0009012222290039062, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0006313323974609375, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.000461578369140625, + 0.00001239776611328125, + -0.055993080139160156, + -0.24974536895751953, + 0.00001239776611328125, + 0.0011262893676757812, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0025796890258789062, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.030013084411621094, + -0.012925148010253906, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0253448486328125, + 0.00001239776611328125, + 0.0012464523315429688, + 0.021536827087402344, + 0.00001239776611328125, + 0.00001239776611328125, + -0.00009822845458984375, + 0.00001239776611328125, + -0.09924793243408203, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.006188392639160156, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0010576248168945312, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.0008172988891601562, + 0.00001239776611328125, + 0.00001239776611328125, + 
-0.0020704269409179688, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.09985160827636719, + 0.00001239776611328125, + 0.00001239776611328125, + 0.036945343017578125, + 0.025011062622070312, + 0.00001239776611328125, + 0.004599571228027344, + 0.00001239776611328125, + 0.00001239776611328125, + 0.027939796447753906, + -0.07974910736083984, + 0.00001239776611328125, + 0.00001239776611328125, + -0.00038242340087890625, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.035175323486328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0047245025634765625, + -0.008166313171386719, + -0.008578300476074219, + 0.0018529891967773438, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0016679763793945312, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0028676986694335938, + -0.04880046844482422, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.0053462982177734375, + 0.00001239776611328125, + 0.1658468246459961, + 0.00001239776611328125, + -0.0024824142456054688, + 0.00001239776611328125, + 0.025139808654785156, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.027915000915527344, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.14544200897216797, + 0.00001239776611328125, + 0.020270347595214844, + 0.007473945617675781, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + -0.8424196243286133, + 0.00001239776611328125, + -0.007409095764160156, + -0.00318145751953125, + -0.015982627868652344, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00001239776611328125, + 0.00034046173095703125, + 0.10727787017822266, + 0.00001239776611328125, + 0.00001239776611328125, + 
-0.00000762939453125, + -0.00000762939453125, + 0.0019397735595703125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.022940635681152344, + 0.07428932189941406, + 0.29994869232177734, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.016974449157714844, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.4772310256958008, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.05463600158691406, + -0.00000762939453125, + -0.00000762939453125, + 0.004734992980957031, + -0.12352275848388672, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.15236186981201172, + -0.00000762939453125, + 0.27855396270751953, + -0.00000762939453125, + -0.00000762939453125, + 0.001430511474609375, + -0.00000762939453125, + -0.00000762939453125, + -0.016387939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.0008668899536132812, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.06814861297607422, + -0.00000762939453125, + 0.00351715087890625, + -0.00000762939453125, + 0.0061588287353515625, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.02480602264404297, + 0.31668567657470703, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.4413900375366211, + -0.00000762939453125, + 0.1517791748046875, + -0.00000762939453125, + -0.00000762939453125, + 0.010898590087890625, + -0.00000762939453125, + 0.006583213806152344, + 
-0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.04946422576904297, + -0.00000762939453125, + -0.00000762939453125, + 0.040429115295410156, + 0.1020956039428711, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.0008649826049804688, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.0054836273193359375, + -0.010519981384277344, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.004588127136230469, + -0.006558418273925781, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.07750797271728516, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.03235149383544922, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.02773571014404297, + 0.08978557586669922, + -0.00000762939453125, + 0.008780479431152344, + -0.00000762939453125, + -0.0327301025390625, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + -0.00000762939453125, + 0.035370826721191406, + 0.0041904449462890625, + -0.0000057220458984375, + -0.0000057220458984375, + 0.008458137512207031, + -0.0000057220458984375, + -0.0042858123779296875, + -0.0000057220458984375, + 0.002468109130859375, + -0.0000057220458984375, + -0.03716564178466797, + -0.10456657409667969, + -0.0047702789306640625, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + 0.017671585083007812, + 0.0004062652587890625, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.00655364990234375, + 0.001873016357421875, + -0.0000057220458984375, + 
-0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.04653644561767578, + -0.01836395263671875, + 0.014448165893554688, + -0.0209197998046875, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.006618499755859375, + 0.02408599853515625, + -0.0000057220458984375, + 0.0012884140014648438, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0050144195556640625, + -0.0000057220458984375, + 0.036708831787109375, + -0.0000057220458984375, + -0.0056591033935546875, + -0.0004215240478515625, + 0.0014057159423828125, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.033232688903808594, + -0.0000057220458984375, + -0.0000057220458984375, + -0.008130073547363281, + 0.016930580139160156, + -0.0000057220458984375, + -0.0012025833129882812, + -0.0000057220458984375, + 0.000545501708984375, + -0.0000057220458984375, + -0.0004673004150390625, + -0.0000057220458984375, + 0.0038089752197265625, + -0.008646011352539062, + -0.0000057220458984375, + -0.008909225463867188, + -0.011255264282226562, + -0.0000057220458984375, + 0.0925750732421875, + -0.0064563751220703125, + -0.0000057220458984375, + 0.0011615753173828125, + 0.00002956390380859375, + -0.0000057220458984375, + 0.07063961029052734, + -0.030902862548828125, + -0.0000057220458984375, + -0.0010814666748046875, + 0.00038909912109375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0013380050659179688, + -0.022397994995117188, + 0.027740478515625, + -0.0000057220458984375, + -0.0000057220458984375, + 0.01797771453857422, + 0.009552955627441406, + 0.01857471466064453, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0055103302001953125, + -0.0000057220458984375, + -0.0000057220458984375, + -0.01697063446044922, + -0.0159149169921875, + 
-0.011240959167480469, + 0.000301361083984375, + 0.020501136779785156, + -0.0000057220458984375, + -0.0006427764892578125, + -0.0000057220458984375, + 0.04800224304199219, + -0.0000057220458984375, + -0.0000057220458984375, + -0.0000057220458984375, + -0.00001811981201171875, + 0.0005950927734375, + 0.00732421875, + -0.0000057220458984375, + 0.001216888427734375, + 0.00897216796875, + -0.1255035400390625, + 0.001003265380859375, + 0.006274223327636719, + 0.0026502609252929688, + -0.00449371337890625, + -0.0023517608642578125, + -0.0000057220458984375, + -0.0000057220458984375, + -0.6521244049072266, + -0.009072303771972656, + 0.013387680053710938, + -0.0000057220458984375, + -0.022745132446289062, + -0.0000057220458984375, + -0.0000057220458984375, + 0.000606536865234375, + -0.0011501312255859375, + -0.0000057220458984375, + -0.0000057220458984375, + 0.023046493530273438, + -0.0000057220458984375, + -0.008263587951660156, + -0.0037221908569335938, + 0.00225830078125, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.05941295623779297, + 0.04140281677246094, + 0.24284648895263672, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.11462688446044922, + 0.012240409851074219, + 0.0012884140014648438, + -0.00001049041748046875, + 0.01781749725341797, + 0.005211830139160156, + -0.0016298294067382812, + -0.2994966506958008, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 0.26962947845458984, + -0.00001049041748046875, + 0.050202369689941406, + 0.04053211212158203, + -0.30355358123779297, + -0.00001049041748046875, + -0.0013666152954101562, + -0.00001049041748046875, + 0.06442928314208984, + -0.00001049041748046875, + 0.04406547546386719, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 
-0.00763702392578125, + -0.00001049041748046875, + -0.03402233123779297, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.005751609802246094, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 0.13191986083984375, + -0.00001049041748046875, + -0.031653404235839844, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.02394390106201172, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.1073293685913086, + 0.20270729064941406, + 0.02746295928955078, + -0.00001049041748046875, + 0.020377159118652344, + -0.31055259704589844, + -0.043480873107910156, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 0.04507160186767578, + -0.0014734268188476562, + -0.00001049041748046875, + 0.048813819885253906, + -0.00001049041748046875, + -0.00001049041748046875, + -0.1407604217529297, + -0.00001049041748046875, + -0.00001049041748046875, + 0.013548851013183594, + 0.016210556030273438, + -0.00001049041748046875, + -0.011261940002441406, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00029277801513671875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 0.008993148803710938, + -0.020813941955566406, + -0.00001049041748046875, + -0.00001049041748046875, + -0.008435249328613281, + -0.021961212158203125, + -0.04410362243652344, + 0.1307668685913086, + 0.005297660827636719, + -0.00001049041748046875, + 0.006031990051269531, + 0.016150474548339844, + -0.00001049041748046875, + -0.00001049041748046875, + 0.01802349090576172, + 0.0018205642700195312, + 0.0016574859619140625, + 0.0005712509155273438, + -0.00001049041748046875, + 
-0.00001049041748046875, + 0.02598094940185547, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 0.02737903594970703, + 0.039580345153808594, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 0.09876728057861328, + 0.035803794860839844, + -0.00001049041748046875, + -0.027251243591308594, + -0.00001049041748046875, + -0.07061004638671875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + -0.00001049041748046875, + 0.08719158172607422 + ], + "xaxis": "x", + "y": [ + 0.001567811705172062, + 0, + 0, + 0.001697835512459278, + 0.00011560246639419347, + 0, + 0, + -0.0002851475146599114, + 0, + -0.030827227979898453, + -0.06409652531147003, + 0, + 0.00015167289529927075, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.013627146370708942, + -0.004393726587295532, + 0.0015328703448176384, + -0.0038613511715084314, + 0, + 0, + 0, + 0, + -0.02049136720597744, + 0, + 0, + -0.007114107254892588, + 0.0003477374848444015, + -0.001384311355650425, + 0.003183899214491248, + 0.0004558839718811214, + -0.059277813881635666, + -0.0035793157294392586, + -0.00589390005916357, + 0, + -0.0001910730206873268, + 0, + 0.0006608504336327314, + -0.0004212319909129292, + 0, + -0.003545185085386038, + -0.00327106611803174, + 0, + 0, + 0.0040074847638607025, + 0, + 0, + 0, + -0.0026069351006299257, + 0, + 0, + 0, + -0.00008433026232523844, + -0.00018646706303115934, + 0, + -0.00439279293641448, + -0.013254894874989986, + 0, + 0.050094299018383026, + -0.021308520808815956, + 0, + -0.0006410681526176631, + 0, + 0, + 0.02329532988369465, + -0.05166983604431152, + -0.002982117934152484, + -0.000014124364497547504, + -0.0020334068685770035, + 0, + 0, + 0, + -0.02020590752363205, + 0.00998645182698965, + 0, + -0.004585121292620897, + 0.005916096270084381, + 0.0018219061894342303, + 0.005700498353689909, + 
0.0008085825829766691, + 0, + 0, + -0.0032405084930360317, + 0, + 0, + 0, + -0.014961971901357174, + 0, + 0, + 0, + 0, + -0.016915086656808853, + 0, + 0.016825370490550995, + 0, + -0.00311169121414423, + 0, + 0.005266942549496889, + 0, + 0, + 0, + -0.009660078212618828, + 0.0010975055629387498, + 0.006078756880015135, + 0, + 0, + 0.003166533075273037, + -0.044512320309877396, + 0.0002630578528624028, + 0, + 0, + -0.00025422731414437294, + 0, + 0, + 0, + -0.3718416392803192, + -0.0008081833366304636, + 0.00043700754758901894, + 0, + -0.023154418915510178, + 0.00004691413778346032, + 0, + 0, + -0.0002914638607762754, + 0.0006733346963301301, + 0, + 0.008972969837486744, + 0, + -0.008168808184564114, + 0, + 0, + -0.0006953283445909619, + 0, + 0, + 0, + -0.001286927843466401, + -0.017273705452680588, + 0, + 0.05898163467645645, + 0.013462062925100327, + 0, + 0, + -0.00003325308352941647, + -0.0027551515959203243, + -0.004652985371649265, + 0, + 0, + 0, + 0, + 0, + 0, + -0.21421866118907928, + 0, + 0, + 0.002191215055063367, + 0, + 0.07645706832408905, + 0.0052618952468037605, + 0, + 0, + -0.020269982516765594, + 0, + 0.013446477241814137, + 0.0068704248405992985, + 0.08710267394781113, + 0, + 0, + 0, + 0, + 0.028982989490032196, + 0, + 0, + 0.014961526729166508, + 0, + 0, + 0, + 0, + 0.011233230121433735, + 0, + -0.009112805128097534, + 0.003226917004212737, + 0, + 0, + 0, + 0.112985759973526, + 0, + 0.028253009542822838, + 0, + 0, + 0, + 0.0009787877788767219, + 0, + 0, + 0, + 0, + -0.03986968472599983, + 0, + 0, + -0.006135094445198774, + 0.04977395758032799, + 0.0397123359143734, + 0.027974072843790054, + 0, + -0.00044811973930336535, + -0.10083132237195969, + 0.000008234118467953522, + 0.06165996566414833, + 0, + 0, + 0.021058127284049988, + 0, + 0, + 0, + 0, + 0, + -0.08074336498975754, + -0.009298793971538544, + 0, + 0.012482613325119019, + 0.06513619422912598, + 0, + 0, + 0.00029019018984399736, + 0, + 0, + 0.0014882637187838554, + 0, + 0, + 0, + 
-0.004803473129868507, + 0, + 0, + 0.025678949430584908, + 0, + -0.04240157827734947, + 0, + 0, + 0.0015190609265118837, + 0.0006482255994342268, + 0.03654245659708977, + 0, + 0, + 0, + 0, + 0, + 0.0020186977926641703, + 0, + 0, + 0.17831696569919586, + 0, + 0, + 0, + 0.0005887048901058733, + 0.012331255711615086, + 0, + 0, + 0.11619613319635391, + 0.04687207192182541, + 0.03033648431301117, + 0, + -0.004195880610495806, + 0.00006391256465576589, + -0.03162289038300514, + 0, + 0, + 0, + 0, + 0, + 0, + 0.03672636300325394, + 0.00788492988795042, + 0, + 0, + 0.003685369621962309, + 0, + 0, + 0, + -0.010384900495409966, + 0, + -0.1327948272228241, + -0.22788244485855103, + 0, + 0.0003893508983310312, + 0, + 0, + 0, + 0.009530982933938503, + 0, + -0.0001286355109186843, + 0, + 0, + 0.0001596187794348225, + 0.011789986863732338, + -0.0022452236153185368, + 0, + -0.0014552043285220861, + 0.0002036036894423887, + -0.03003234602510929, + -0.036742936819791794, + -0.028862446546554565, + -0.003727517556399107, + 0, + 0.0011460097739472985, + 0, + -0.027142589911818504, + -0.054151974618434906, + -0.0004727205669041723, + 0, + -0.006094601005315781, + 0.00013960858632344753, + 0, + -0.0003665595140773803, + 0.00028091753483749926, + -0.17846877872943878, + -0.004990901332348585, + -0.010615025646984577, + 0, + 0.015916047617793083, + 0, + 0.0008773574372753501, + 0.004459311719983816, + 0, + -0.015235064551234245, + 0, + -0.0008741968194954097, + -0.04074608162045479, + 0.007227533031255007, + 0, + 0, + 0, + -0.007763775996863842, + 0, + 0, + -0.0011336231837049127, + 0, + -0.004542750306427479, + 0.016146792098879814, + -0.032868705689907074, + -0.013282506726682186, + 0, + 0.1884474903345108, + -0.07819699496030807, + 0, + 0.00013099861098453403, + 0.00024322106037288904, + 0, + 0.04764547944068909, + -0.09056885540485382, + 0, + -0.005007788073271513, + 0.000487087934743613, + 0, + 0, + 0, + -0.07196655869483948, + 0.007451012264937162, + 0, + -0.013892672955989838, + 
-0.005596193019300699, + -0.005349555052816868, + -0.00015437132969964296, + 0, + 0, + 0, + -0.00894666463136673, + 0, + 0, + 0, + -0.036862581968307495, + 0, + 0, + 0, + 0, + -0.026162482798099518, + 0, + 0.046491872519254684, + 0, + -0.030160455033183098, + 0, + -0.009029642678797245, + -0.0021479984279721975, + -0.0005375721957534552, + -0.002135993679985404, + -0.027962258085608482, + 0, + 0.0008057129452936351, + 0, + 0, + 0, + -0.26795026659965515, + 0, + 0, + 0, + 0, + 0, + -0.0027670287527143955, + -0.0002252299600513652, + -0.7548060417175293, + 0, + -0.05009680241346359, + 0, + -0.03914204612374306, + 0, + 0, + 0.016279445961117744, + 0, + 0, + 0, + 0.025662390515208244, + -0.000049459828005637974, + -0.0023572721984237432, + 0, + 0, + 0.0009027881897054613, + 0, + 0, + 0, + 0, + -0.01007400918751955, + 0.07334298640489578, + 0.15174342691898346, + 0, + 0, + 0, + 0, + 0, + 0.0007311829249374568, + 0, + 0, + 0, + 0.011839455924928188, + 0, + 0, + -0.2282165139913559, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.017542533576488495, + 0, + 0, + 0, + 0.1636323779821396, + 0, + 0.10289037227630615, + 0, + 0, + 0, + 0, + 0, + 0.024433566257357597, + 0, + 0, + 0, + 0, + 0.013018166646361351, + 0, + -0.0005916667287237942, + 0, + 0, + 0, + 0, + 0.07111621648073196, + 0.0004984873230569065, + 0.015917964279651642, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.06262800097465515, + 0.17253385484218597, + 0, + 0, + -0.0007970984443090856, + -0.1451263427734375, + 0, + 0.08718064427375793, + 0, + 0, + 0.007446629460901022, + 0, + 0, + 0, + 0, + 0, + -0.09546831995248795, + 0, + 0, + 0.06110787391662598, + 0.08931172639131546, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.005256101489067078, + -0.00553735950961709, + 0, + 0, + 0.006732907146215439, + 0, + -0.005547903478145599, + 0, + 0, + 0, + 0.01766844280064106, + -0.0034187675919383764, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.1122211441397667, + 0, + 0, + 0, + 0.009442206472158432, + 0, + 0, + 
0, + 0, + 0, + 0.00800288561731577, + 0, + 0.006613056641072035, + 0, + -0.06462590396404266, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0009047402418218553, + 0, + 0, + -0.0005877931835129857, + 0, + 0, + 0, + -0.0004729636711999774, + 0, + -0.05036322772502899, + -0.24687804281711578, + 0, + 0.001115482416935265, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0024291854351758957, + 0, + 0, + 0, + -0.029154174029827118, + -0.011211197823286057, + 0, + 0, + 0, + 0, + 0, + -0.0075091151520609856, + 0, + 0.0037634933833032846, + 0.022711526602506638, + 0, + 0, + -0.00011145337339257821, + 0, + -0.08350298553705215, + 0, + 0, + 0, + 0, + 0, + 0.0063380529172718525, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0010615212377160788, + 0, + 0, + 0, + 0.001314864493906498, + 0, + 0, + -0.0020079570822417736, + 0, + 0, + 0, + -0.095857173204422, + 0, + 0, + 0.04977884888648987, + 0.04924672096967697, + 0, + 0.00675918348133564, + 0, + 0, + 0.02823697216808796, + -0.07869893312454224, + 0, + 0, + -0.00039145027403719723, + 0, + 0, + 0, + -0.03502006456255913, + 0, + 0, + -0.004709419794380665, + -0.007543480955064297, + -0.007213911972939968, + 0.0026987697929143906, + 0, + 0, + 0, + -0.0016787010245025158, + 0, + 0, + -0.002866228111088276, + -0.04759479686617851, + 0, + 0, + 0, + 0, + -0.005348640959709883, + 0, + 0.17661413550376892, + 0, + -0.0024743194226175547, + 0, + 0.0269751138985157, + 0, + 0, + 0, + -0.025461290031671524, + 0, + 0, + 0, + 0, + 0, + -0.14607883989810944, + 0, + 0.020490022376179695, + 0.007573024369776249, + 0, + 0, + 0, + 0, + -0.8939738869667053, + 0, + -0.006900197826325893, + -0.0031849159859120846, + -0.015817783772945404, + 0, + 0, + 0, + 0, + 0, + 0.00032859406201168895, + 0.11629504710435867, + 0, + 0, + 0, + 0, + 0.0020032059401273727, + 0, + 0, + 0, + 0, + -0.02256190776824951, + 0.07616151124238968, + 0.3106333911418915, + 0, + 0, + 0, + 0, + 0, + -0.014044971205294132, + 0, + 0, + 0, + 0, + 0, + 0, + -0.3483165502548218, + 0, + 0, + 0, 
+ 0, + 0.05930393189191818, + 0, + 0, + 0.004992437083274126, + -0.08404884487390518, + 0, + 0, + 0, + 0.16281214356422424, + 0, + 0.28443410992622375, + 0, + 0, + 0.0014393558958545327, + 0, + 0, + -0.009063852950930595, + 0, + 0, + 0, + 0, + 0.001169737195596099, + 0, + 0, + 0, + 0, + 0, + 0, + 0.06898342072963715, + 0, + 0.007991905324161053, + 0, + 0.006260615773499012, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.037955716252326965, + 0.3505173921585083, + 0, + 0, + 0, + -0.338177889585495, + 0, + 0.158599853515625, + 0, + 0, + 0.01131439208984375, + 0, + 0.006751265376806259, + 0, + 0, + 0, + -0.04573351889848709, + 0, + 0, + 0.04386100172996521, + 0.11277603358030319, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.0003205372195225209, + 0, + 0, + 0, + -0.005409737583249807, + -0.009204162284731865, + 0, + 0, + 0, + 0.004804544150829315, + -0.005810749251395464, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.09645535796880722, + 0, + 0, + 0, + 0.032931435853242874, + 0, + 0, + 0, + 0, + 0.028524864464998245, + 0.09402520954608917, + 0, + 0.008998546749353409, + 0, + -0.03251685947179794, + 0, + 0, + 0, + 0, + 0, + 0, + 0.037343256175518036, + 0.0045840502716600895, + 0, + 0, + 0.009021798148751259, + 0, + -0.004217533860355616, + 0, + 0.0025705555453896523, + 0, + -0.035309672355651855, + -0.09942735731601715, + -0.004700342193245888, + 0, + 0, + 0, + 0.018288278952240944, + 0.0004021169152110815, + 0, + 0, + 0, + 0, + -0.005593586713075638, + 0.0018821493722498417, + 0, + 0, + 0, + 0, + -0.04561242088675499, + -0.01815006509423256, + 0.016583485528826714, + -0.020843051373958588, + 0, + 0, + 0, + -0.006372869946062565, + 0.04272369295358658, + 0, + 0.0013309348141774535, + 0, + 0, + 0, + -0.0031638317741453648, + 0, + 0.08714215457439423, + 0, + -0.005442100111395121, + -0.00039313771412707865, + 0.0014464370906352997, + 0, + 0, + 0, + 0, + -0.03132649511098862, + 0, + 0, + -0.007972904480993748, + 0.01753396727144718, + 0, + -0.0011563192820176482, 
+ 0, + 0.0017362519865855575, + 0, + -0.0004587600124068558, + 0, + 0.0038881096988916397, + -0.008516360074281693, + 0, + -0.008183307014405727, + -0.010095844976603985, + 0, + 0.10722006857395172, + -0.002898464212194085, + 0, + 0.0012827662285417318, + 0.00004252225699019618, + 0, + 0.07567721605300903, + -0.030121177434921265, + 0, + -0.0010666534071788192, + 0.0006539365276694298, + 0, + 0, + -0.0011567147448658943, + -0.021622339263558388, + 0.028687214478850365, + 0, + 0, + 0.018764594569802284, + 0.010613140650093555, + 0.019510075449943542, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.005288010463118553, + 0, + 0, + -0.016743114218115807, + -0.015873711556196213, + -0.009877816773951054, + 0.0003150522243231535, + 0.023689158260822296, + 0, + -0.00033418016391806304, + 0, + 0.04904749244451523, + 0, + 0, + 0, + 0.0006500506424345076, + 0.000622213410679251, + 0.00738720316439867, + 0, + 0.0012243357487022877, + 0.009066173806786537, + -0.12073952704668045, + 0.0010678119724616408, + 0.006296947598457336, + 0.002682592486962676, + -0.00444818427786231, + -0.0023324599023908377, + 0, + 0, + -0.5609893798828125, + -0.008780602365732193, + 0.015986066311597824, + 0, + -0.02213476411998272, + 0, + 0, + 0.0006705078994855285, + -0.0011221399763599038, + 0, + 0, + 0.025299811735749245, + 0, + -0.008218510076403618, + -0.0034782839938998222, + 0.0022423912305384874, + 0, + 0, + 0, + 0, + 0, + -0.05859537422657013, + 0.0421387143433094, + 0.26256099343299866, + 0, + 0, + 0, + 0, + 0, + -0.10330676287412643, + 0.012355834245681763, + 0.0013472040882334113, + 0, + 0.019914263859391212, + 0.005261276848614216, + 0.001149827498011291, + -0.03320133313536644, + 0, + 0, + 0, + 0, + 0.32198745012283325, + 0, + 0.05401667580008507, + 0.04610951617360115, + -0.2326284795999527, + 0, + 0.0000856258993735537, + 0, + 0.074106365442276, + 0, + 0.044469863176345825, + 0, + 0, + 0, + -0.006453251000493765, + 0, + -0.018431225791573524, + 0, + 0, + 0, + 0, + -0.005704954732209444, + 0, 
+ 0, + 0, + 0, + 0, + 0, + 0.13457728922367096, + 0, + -0.029186677187681198, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -0.022995056584477425, + 0, + 0, + 0, + -0.09004921466112137, + 0.24257110059261322, + 0.02852930873632431, + 0, + 0.021270141005516052, + -0.13564155995845795, + -0.03098711557686329, + 0, + 0, + 0, + 0.0486220121383667, + -0.001395023544318974, + 0, + 0.04929636791348457, + 0, + 0, + -0.13068373501300812, + 0, + 0, + 0.016955919563770294, + 0.03848254308104515, + 0, + -0.011160435155034065, + 0, + 0, + -0.0002991429646499455, + 0, + 0, + 0, + 0.01138608530163765, + -0.020150866359472275, + 0, + 0, + -0.007353566121309996, + -0.021389631554484367, + -0.042083244770765305, + 0.13586723804473877, + 0.005315479822456837, + 0, + 0.008157049305737019, + 0.022239860147237778, + 0, + 0, + 0.01896926946938038, + 0.0018052944215014577, + 0.0016496418975293636, + 0.0005593635141849518, + 0, + 0, + 0.07655386626720428, + 0, + 0, + 0, + 0.02781328558921814, + 0.04012482985854149, + 0, + 0, + 0, + 0.10631410032510757, + 0.03608629107475281, + 0, + -0.02651066705584526, + 0, + -0.0690990686416626, + 0, + 0, + 0, + 0, + 0, + 0, + 0.1022648885846138 + ], + "yaxis": "y" + } + ], + "layout": { + "legend": { + "tracegroupgap": 0 + }, + "shapes": [ + { + "line": { + "color": "gray", + "dash": "dot", + "width": 1 + }, + "type": "line", + "x0": -0.8653240203857422, + "x1": 0.31668567657470703, + "y0": -0.8653240203857422, + "y1": 0.31668567657470703 + } + ], + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + 
"aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 
0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + 
{ + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, 
+ "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + 
"backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Attribution vs Activation Patching Per SAE feature (L5 S2 Pos, all prompts)" + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Activation Patch" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Attribution Patch" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig = scatter(\n", + " y=sae_act_attr[:, s2_pos, all_live_features].flatten(), \n", + " x=causal_effects.flatten(),\n", + " title=\"Attribution vs Activation Patching Per SAE feature (L5 S2 Pos, all prompts)\",\n", + " xaxis=\"Activation Patch\",\n", + " yaxis=\"Attribution Patch\",\n", + " return_fig=True\n", + ")\n", + "fig.add_shape(\n", + " type='line',\n", + " x0=causal_effects.min(),\n", + " y0=causal_effects.min(),\n", + " x1=causal_effects.max(),\n", + " y1=causal_effects.max(),\n", + " line=dict(\n", + " color='gray',\n", + " width=1,\n", + " dash='dot'\n", + " )\n", + ")\n", + "fig.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/unit/test_hooked_sae.py b/tests/unit/test_hooked_sae.py new file mode 100644 index 000000000..de311772e --- /dev/null +++ b/tests/unit/test_hooked_sae.py @@ -0,0 +1,191 @@ +import einops +import pytest +import torch + +from transformer_lens import HookedSAE, HookedSAEConfig, HookedSAETransformer + +MODEL = "solu-1l" +prompt = "Hello World!" 
+ + +class Counter: + def __init__(self): + self.count = 0 + + def inc(self, *args, **kwargs): + self.count += 1 + + +@pytest.fixture(scope="module") +def model(): + model = HookedSAETransformer.from_pretrained(MODEL) + yield model + model.reset_saes() + + +def get_sae_config(model, act_name): + site_to_size = { + "hook_z": model.cfg.d_head * model.cfg.n_heads, + "hook_mlp_out": model.cfg.d_model, + "hook_resid_pre": model.cfg.d_model, + "hook_post": model.cfg.d_mlp, + } + site = act_name.split(".")[-1] + d_in = site_to_size[site] + return HookedSAEConfig(d_in=d_in, d_sae=d_in * 2, hook_name=act_name) + + +@pytest.mark.parametrize( + "act_name", + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], +) +def test_forward_reconstructs_input(model, act_name): + """Verfiy that the HookedSAE returns an output with the same shape as the input activations.""" + sae_cfg = get_sae_config(model, act_name) + hooked_sae = HookedSAE(sae_cfg) + + _, cache = model.run_with_cache(prompt, names_filter=act_name) + x = cache[act_name] + + sae_output = hooked_sae(x) + assert sae_output.shape == x.shape + + +@pytest.mark.parametrize( + "act_name", + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], +) +def test_run_with_cache(model, act_name): + """Verifies that run_with_cache caches SAE activations""" + sae_cfg = get_sae_config(model, act_name) + hooked_sae = HookedSAE(sae_cfg) + + _, cache = model.run_with_cache(prompt, names_filter=act_name) + x = cache[act_name] + + sae_output, cache = hooked_sae.run_with_cache(x) + assert sae_output.shape == x.shape + + assert "hook_sae_input" in cache + assert "hook_sae_acts_pre" in cache + assert "hook_sae_acts_post" in cache + assert "hook_sae_recons" in cache + assert "hook_sae_output" in cache + + +@pytest.mark.parametrize( + "act_name", + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + 
"blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], +) +def test_run_with_hooks(model, act_name): + """Verifies that run_with_hooks works with SAE activations""" + c = Counter() + sae_cfg = get_sae_config(model, act_name) + hooked_sae = HookedSAE(sae_cfg) + + _, cache = model.run_with_cache(prompt, names_filter=act_name) + x = cache[act_name] + + sae_hooks = [ + "hook_sae_input", + "hook_sae_acts_pre", + "hook_sae_acts_post", + "hook_sae_recons", + "hook_sae_output", + ] + + sae_output = hooked_sae.run_with_hooks( + x, fwd_hooks=[(sae_hook_name, c.inc) for sae_hook_name in sae_hooks] + ) + assert sae_output.shape == x.shape + + assert c.count == len(sae_hooks) + + +@pytest.mark.parametrize( + "act_name", + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], +) +def test_error_term(model, act_name): + """Verifies that that if we use error_terms, HookedSAE returns an output that is equal to the input activations.""" + sae_cfg = get_sae_config(model, act_name) + sae_cfg.use_error_term = True + hooked_sae = HookedSAE(sae_cfg) + + _, cache = model.run_with_cache(prompt, names_filter=act_name) + x = cache[act_name] + + sae_output = hooked_sae(x) + assert sae_output.shape == x.shape + assert torch.allclose(sae_output, x, atol=1e-6) + + +# %% +@pytest.mark.parametrize( + "act_name", + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], +) +def test_feature_grads_with_error_term(model, act_name): + """Verifies that pytorch backward computes the correct feature gradients when using error_terms. 
Motivated by the need to compute feature gradients for attribution patching.""" + + # Load SAE + sae_cfg = get_sae_config(model, act_name) + sae_cfg.use_error_term = True + hooked_sae = HookedSAE(sae_cfg) + + # Get input activations + _, cache = model.run_with_cache(prompt, names_filter=act_name) + x = cache[act_name] + + # Cache gradients with respect to feature acts + hooked_sae.reset_hooks() + grad_cache = {} + + def backward_cache_hook(act, hook): + grad_cache[hook.name] = act.detach() + + hooked_sae.add_hook("hook_sae_acts_post", backward_cache_hook, "bwd") + hooked_sae.add_hook("hook_sae_output", backward_cache_hook, "bwd") + + sae_output = hooked_sae(x) + assert torch.allclose(sae_output, x, atol=1e-6) + value = sae_output.sum() + value.backward() + hooked_sae.reset_hooks() + + # Compute gradient analytically + if act_name.endswith("hook_z"): + reshaped_output_grad = einops.rearrange( + grad_cache["hook_sae_output"], "... n_heads d_head -> ... (n_heads d_head)" + ) + analytic_grad = reshaped_output_grad @ hooked_sae.W_dec.T + else: + analytic_grad = grad_cache["hook_sae_output"] @ hooked_sae.W_dec.T + + # Compare analytic gradient with pytorch computed gradient + assert torch.allclose(grad_cache["hook_sae_acts_post"], analytic_grad, atol=1e-6) diff --git a/tests/unit/test_hooked_sae_transformer.py b/tests/unit/test_hooked_sae_transformer.py new file mode 100644 index 000000000..bfb428c8e --- /dev/null +++ b/tests/unit/test_hooked_sae_transformer.py @@ -0,0 +1,515 @@ +import pytest +import torch + +from transformer_lens import ( + HookedSAE, + HookedSAEConfig, + HookedSAETransformer, + HookedTransformer, +) +from transformer_lens.ActivationCache import ActivationCache +from transformer_lens.hook_points import HookPoint # Hooking utilities +from transformer_lens.HookedSAETransformer import get_deep_attr + +MODEL = "solu-1l" +prompt = "Hello World!" 
+ + +class Counter: + def __init__(self): + self.count = 0 + + def inc(self, *args, **kwargs): + self.count += 1 + + +@pytest.fixture(scope="module") +def original_logits(): + original_model = HookedTransformer.from_pretrained(MODEL) + return original_model(prompt) + + +@pytest.fixture(scope="module") +def model(): + model = HookedSAETransformer.from_pretrained(MODEL) + yield model + model.reset_saes() + + +def get_sae_config(model, act_name): + site_to_size = { + "hook_z": model.cfg.d_head * model.cfg.n_heads, + "hook_mlp_out": model.cfg.d_model, + "hook_resid_pre": model.cfg.d_model, + "hook_post": model.cfg.d_mlp, + } + site = act_name.split(".")[-1] + d_in = site_to_size[site] + return HookedSAEConfig(d_in=d_in, d_sae=d_in * 2, hook_name=act_name) + + +def test_model_with_no_saes_matches_original_model(model, original_logits): + """Verifies that HookedSAETransformer behaves like a normal HookedTransformer model when no SAEs are attached.""" + assert len(model.acts_to_saes) == 0 + logits = model(prompt) + assert torch.allclose(original_logits, logits) + + +@pytest.mark.parametrize( + "act_name", + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], +) +def test_model_with_saes_does_not_match_original_model(model, act_name, original_logits): + """Verifies that the attached (and turned on) SAEs actually affect the models output logits""" + assert len(model.acts_to_saes) == 0 + sae_cfg = get_sae_config(model, act_name) + hooked_sae = HookedSAE(sae_cfg) + model.add_sae(hooked_sae) + assert len(model.acts_to_saes) == 1 + logits_with_saes = model(prompt) + assert not torch.allclose(original_logits, logits_with_saes) + model.reset_saes() + + +@pytest.mark.parametrize( + "act_name", + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], +) +def test_add_sae(model, act_name): + """Verifies that add_sae correctly updates the model's acts_to_saes 
dictionary and replaces the HookPoint.""" + sae_cfg = get_sae_config(model, act_name) + hooked_sae = HookedSAE(sae_cfg) + model.add_sae(hooked_sae) + assert len(model.acts_to_saes) == 1 + assert model.acts_to_saes[act_name] == hooked_sae + assert get_deep_attr(model, act_name) == hooked_sae + model.reset_saes() + + +@pytest.mark.parametrize( + "act_name", + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], +) +def test_add_sae_overwrites_prev_sae(model, act_name): + """Verifies that add_sae correctly updates the model's acts_to_saes dictionary and replaces the HookPoint.""" + prev_sae_cfg = get_sae_config(model, act_name) + prev_hooked_sae = HookedSAE(prev_sae_cfg) + model.add_sae(prev_hooked_sae) + assert len(model.acts_to_saes) == 1 + assert model.acts_to_saes[act_name] == prev_hooked_sae + assert get_deep_attr(model, act_name) == prev_hooked_sae + + sae_cfg = get_sae_config(model, act_name) + hooked_sae = HookedSAE(sae_cfg) + model.add_sae(hooked_sae) + assert len(model.acts_to_saes) == 1 + assert model.acts_to_saes[act_name] == hooked_sae + assert get_deep_attr(model, act_name) == hooked_sae + model.reset_saes() + + +@pytest.mark.parametrize( + "act_name", + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], +) +def test_reset_sae_removes_sae_by_default(model, act_name): + """Verifies that reset_sae correctly removes the SAE from the model's acts_to_saes dictionary and replaces the HookedSAE with a HookPoint.""" + sae_cfg = get_sae_config(model, act_name) + hooked_sae = HookedSAE(sae_cfg) + model.add_sae(hooked_sae) + assert len(model.acts_to_saes) == 1 + assert model.acts_to_saes[act_name] == hooked_sae + assert get_deep_attr(model, act_name) == hooked_sae + model._reset_sae(act_name) + assert len(model.acts_to_saes) == 0 + assert isinstance(get_deep_attr(model, act_name), HookPoint) + model.reset_saes() + + 
+@pytest.mark.parametrize( + "act_name", + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], +) +def test_reset_sae_replaces_sae(model, act_name): + """Verifies that reset_sae correctly removes the SAE from the model's acts_to_saes dictionary and replaces the HookedSAE with a HookPoint.""" + sae_cfg = get_sae_config(model, act_name) + hooked_sae = HookedSAE(sae_cfg) + + prev_sae_cfg = get_sae_config(model, act_name) + prev_sae = HookedSAE(prev_sae_cfg) + + model.add_sae(hooked_sae) + assert len(model.acts_to_saes) == 1 + assert model.acts_to_saes[act_name] == hooked_sae + assert get_deep_attr(model, act_name) == hooked_sae + model._reset_sae(act_name, prev_sae) + assert len(model.acts_to_saes) == 1 + assert get_deep_attr(model, act_name) == prev_sae + model.reset_saes() + + +@pytest.mark.parametrize( + "act_names", + [ + ["blocks.0.attn.hook_z"], + ["blocks.0.hook_mlp_out"], + ["blocks.0.mlp.hook_post"], + ["blocks.0.hook_resid_pre"], + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], + ], +) +def test_reset_saes_removes_all_saes_by_default(model, act_names): + """Verifies that reset_saes correctly removes all SAEs from the model's acts_to_saes dictionary and replaces the HookedSAEs with HookPoints.""" + sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names] + hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs] + for hooked_sae in hooked_saes: + model.add_sae(hooked_sae) + assert len(model.acts_to_saes) == len(act_names) + for act_name, hooked_sae in zip(act_names, hooked_saes): + assert model.acts_to_saes[act_name] == hooked_sae + assert get_deep_attr(model, act_name) == hooked_sae + model.reset_saes() + assert len(model.acts_to_saes) == 0 + for act_name in act_names: + assert isinstance(get_deep_attr(model, act_name), HookPoint) + model.reset_saes() + + +@pytest.mark.parametrize( + "act_names", + [ + 
["blocks.0.attn.hook_z"], + ["blocks.0.hook_mlp_out"], + ["blocks.0.mlp.hook_post"], + ["blocks.0.hook_resid_pre"], + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], + ], +) +def test_reset_saes_replaces_saes(model, act_names): + """Verifies that reset_saes correctly removes all SAEs from the model's acts_to_saes dictionary and replaces the HookedSAEs with HookPoints.""" + sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names] + hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs] + for hooked_sae in hooked_saes: + model.add_sae(hooked_sae) + + prev_sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names] + prev_hooked_saes = [HookedSAE(prev_sae_cfg) for prev_sae_cfg in prev_sae_cfgs] + + assert len(model.acts_to_saes) == len(act_names) + for act_name, hooked_sae in zip(act_names, hooked_saes): + assert model.acts_to_saes[act_name] == hooked_sae + assert get_deep_attr(model, act_name) == hooked_sae + model.reset_saes(act_names, prev_hooked_saes) + assert len(model.acts_to_saes) == len(prev_hooked_saes) + for act_name, prev_hooked_sae in zip(act_names, prev_hooked_saes): + assert get_deep_attr(model, act_name) == prev_hooked_sae + model.reset_saes() + + +@pytest.mark.parametrize( + "act_names", + [ + ["blocks.0.attn.hook_z"], + ["blocks.0.hook_mlp_out"], + ["blocks.0.mlp.hook_post"], + ["blocks.0.hook_resid_pre"], + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], + ], +) +def test_saes_context_manager_removes_saes_after(model, act_names): + """Verifies that the model.saes context manager successfully adds the SAEs for the specified activation name in the context manager and resets off after the context manager exits.""" + sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names] + hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs] + assert len(model.acts_to_saes) == 0 + for 
act_name in act_names: + assert isinstance(get_deep_attr(model, act_name), HookPoint) + with model.saes(saes=hooked_saes): + for act_name, hooked_sae in zip(act_names, hooked_saes): + assert model.acts_to_saes[act_name] == hooked_sae + assert isinstance(get_deep_attr(model, act_name), HookedSAE) + assert get_deep_attr(model, act_name) == hooked_sae + model.forward(prompt) + assert len(model.acts_to_saes) == 0 + for act_name in act_names: + assert isinstance(get_deep_attr(model, act_name), HookPoint) + model.reset_saes() + + +@pytest.mark.parametrize( + "act_names", + [ + ["blocks.0.attn.hook_z"], + ["blocks.0.hook_mlp_out"], + ["blocks.0.mlp.hook_post"], + ["blocks.0.hook_resid_pre"], + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], + ], +) +def test_saes_context_manager_restores_previous_sae_state(model, act_names): + """Verifies that the model.saes context manager successfully adds the SAEs for the specified activation name in the context manager and resets off after the context manager exits.""" + # First add SAEs statefully + prev_sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names] + prev_hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in prev_sae_cfgs] + for act_name, prev_hooked_sae in zip(act_names, prev_hooked_saes): + model.add_sae(prev_hooked_sae) + assert get_deep_attr(model, act_name) == prev_hooked_sae + assert len(model.acts_to_saes) == len(prev_hooked_saes) + + # Now temporarily run with new SAEs + sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names] + hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs] + with model.saes(saes=hooked_saes): + for act_name, hooked_sae in zip(act_names, hooked_saes): + assert model.acts_to_saes[act_name] == hooked_sae + assert isinstance(get_deep_attr(model, act_name), HookedSAE) + assert get_deep_attr(model, act_name) == hooked_sae + model.forward(prompt) + + # Check that the previously attached SAEs have 
been restored + assert len(model.acts_to_saes) == len(prev_hooked_saes) + for act_name, prev_hooked_sae in zip(act_names, prev_hooked_saes): + assert isinstance(get_deep_attr(model, act_name), HookedSAE) + assert get_deep_attr(model, act_name) == prev_hooked_sae + model.reset_saes() + + +@pytest.mark.parametrize( + "act_names", + [ + ["blocks.0.attn.hook_z"], + ["blocks.0.hook_mlp_out"], + ["blocks.0.mlp.hook_post"], + ["blocks.0.hook_resid_pre"], + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], + ], +) +def test_saes_context_manager_run_with_cache(model, act_names): + """Verifies that the model.run_with_cache method works correctly in the context manager.""" + sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names] + hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs] + assert len(model.acts_to_saes) == 0 + for act_name in act_names: + assert isinstance(get_deep_attr(model, act_name), HookPoint) + with model.saes(saes=hooked_saes): + for act_name, hooked_sae in zip(act_names, hooked_saes): + assert model.acts_to_saes[act_name] == hooked_sae + assert isinstance(get_deep_attr(model, act_name), HookedSAE) + assert get_deep_attr(model, act_name) == hooked_sae + model.run_with_cache(prompt) + assert len(model.acts_to_saes) == 0 + for act_name in act_names: + assert isinstance(get_deep_attr(model, act_name), HookPoint) + model.reset_saes() + + +@pytest.mark.parametrize( + "act_names", + [ + ["blocks.0.attn.hook_z"], + ["blocks.0.hook_mlp_out"], + ["blocks.0.mlp.hook_post"], + ["blocks.0.hook_resid_pre"], + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], + ], +) +def test_run_with_saes(model, act_names, original_logits): + """Verifies that the model.run_with_saes method works correctly. 
The logits with SAEs should be different from the original logits, but the SAE should be removed immediately after the forward pass.""" + sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names] + hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs] + assert len(model.acts_to_saes) == 0 + logits_with_saes = model.run_with_saes(prompt, saes=hooked_saes) + assert not torch.allclose(logits_with_saes, original_logits) + assert len(model.acts_to_saes) == 0 + for act_name in act_names: + assert isinstance(get_deep_attr(model, act_name), HookPoint) + model.reset_saes() + + +@pytest.mark.parametrize( + "act_names", + [ + ["blocks.0.attn.hook_z"], + ["blocks.0.hook_mlp_out"], + ["blocks.0.mlp.hook_post"], + ["blocks.0.hook_resid_pre"], + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], + ], +) +def test_run_with_cache(model, act_names, original_logits): + """Verifies that the model.run_with_cache method works correctly. 
The logits with SAEs should be different from the original logits and the cache should contain SAE activations for the attached SAE.""" + sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names] + hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs] + for hooked_sae in hooked_saes: + model.add_sae(hooked_sae) + assert len(model.acts_to_saes) == len(hooked_saes) + logits_with_saes, cache = model.run_with_cache(prompt) + assert not torch.allclose(logits_with_saes, original_logits) + assert isinstance(cache, ActivationCache) + for act_name, hooked_sae in zip(act_names, hooked_saes): + assert act_name + ".hook_sae_acts_post" in cache + assert isinstance(get_deep_attr(model, act_name), HookedSAE) + assert get_deep_attr(model, act_name) == hooked_sae + model.reset_saes() + + +@pytest.mark.parametrize( + "act_names", + [ + ["blocks.0.attn.hook_z"], + ["blocks.0.hook_mlp_out"], + ["blocks.0.mlp.hook_post"], + ["blocks.0.hook_resid_pre"], + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], + ], +) +def test_run_with_cache_with_saes(model, act_names, original_logits): + """Verifies that the model.run_with_cache_with_saes method works correctly. 
The logits with SAEs should be different from the original logits and the cache should contain SAE activations for the attached SAE.""" + sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names] + hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs] + logits_with_saes, cache = model.run_with_cache_with_saes(prompt, saes=hooked_saes) + assert not torch.allclose(logits_with_saes, original_logits) + assert isinstance(cache, ActivationCache) + + assert len(model.acts_to_saes) == 0 + for act_name, hooked_sae in zip(act_names, hooked_saes): + assert act_name + ".hook_sae_acts_post" in cache + assert isinstance(get_deep_attr(model, act_name), HookPoint) + model.reset_saes() + + +@pytest.mark.parametrize( + "act_names", + [ + ["blocks.0.attn.hook_z"], + ["blocks.0.hook_mlp_out"], + ["blocks.0.mlp.hook_post"], + ["blocks.0.hook_resid_pre"], + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], + ], +) +def test_run_with_hooks(model, act_names, original_logits): + """Verifies that the model.run_with_hooks method works correctly when SAEs are attached. 
The count should be incremented by 1 when the hooked SAE is called, and the SAE should stay attached after the forward pass""" + c = Counter() + sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names] + hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs] + + for hooked_sae in hooked_saes: + model.add_sae(hooked_sae) + + logits_with_saes = model.run_with_hooks( + prompt, fwd_hooks=[(act_name + ".hook_sae_acts_post", c.inc) for act_name in act_names] + ) + assert not torch.allclose(logits_with_saes, original_logits) + + for act_name, hooked_sae in zip(act_names, hooked_saes): + assert isinstance(get_deep_attr(model, act_name), HookedSAE) + assert get_deep_attr(model, act_name) == hooked_sae + assert c.count == len(act_names) + model.reset_saes() + model.remove_all_hook_fns(including_permanent=True) + + +@pytest.mark.parametrize( + "act_names", + [ + ["blocks.0.attn.hook_z"], + ["blocks.0.hook_mlp_out"], + ["blocks.0.mlp.hook_post"], + ["blocks.0.hook_resid_pre"], + [ + "blocks.0.attn.hook_z", + "blocks.0.hook_mlp_out", + "blocks.0.mlp.hook_post", + "blocks.0.hook_resid_pre", + ], + ], +) +def test_run_with_hooks_with_saes(model, act_names, original_logits): + """Verifies that the model.run_with_hooks_with_saes method works correctly when SAEs are attached. 
The count should be incremented by 1 when the hooked SAE is called, but the SAE should be removed immediately after the forward pass.""" + c = Counter() + sae_cfgs = [get_sae_config(model, act_name) for act_name in act_names] + hooked_saes = [HookedSAE(sae_cfg) for sae_cfg in sae_cfgs] + + logits_with_saes = model.run_with_hooks_with_saes( + prompt, + saes=hooked_saes, + fwd_hooks=[(act_name + ".hook_sae_acts_post", c.inc) for act_name in act_names], + ) + assert not torch.allclose(logits_with_saes, original_logits) + assert c.count == len(act_names) + + assert len(model.acts_to_saes) == 0 + for act_name in act_names: + assert isinstance(get_deep_attr(model, act_name), HookPoint) + model.reset_saes() + model.remove_all_hook_fns(including_permanent=True) diff --git a/transformer_lens/HookedSAE.py b/transformer_lens/HookedSAE.py new file mode 100644 index 000000000..df9b29d05 --- /dev/null +++ b/transformer_lens/HookedSAE.py @@ -0,0 +1,118 @@ +from typing import Dict, Union + +import einops +import torch +import torch.nn.functional as F +from jaxtyping import Float +from torch import nn + +from transformer_lens.hook_points import ( # Hooking utilities + HookedRootModule, + HookPoint, +) +from transformer_lens.HookedSAEConfig import HookedSAEConfig + + +class HookedSAE(HookedRootModule): + """Hooked SAE. + + Implements a standard SAE with a TransformerLens hooks for SAE activations + + Designed for inference / analysis, not training. For training, see Joseph Bloom's SAELens (https://github.com/jbloomAus/SAELens) + + Note that HookedSAETransformer is fairly modular, and doesn't make strong assumptions about the architecture of the SAEs that get attached. We provide HookedSAE as a useful default class, but if you want to eg experiment with other SAE architectures, you can just copy the HookedSAE code into a notebook, edit it, and add instances of the new SAE class to a HookedSAETransformer (e.g. 
with HookedSAETransformer.add_sae(sae)) + """ + + def __init__(self, cfg: Union[HookedSAEConfig, Dict]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedSAEConfig(**cfg) + elif isinstance(cfg, str): + raise ValueError("Please pass in a config dictionary or HookedSAEConfig object.") + self.cfg = cfg + + self.W_enc = nn.Parameter( + torch.nn.init.kaiming_uniform_( + torch.empty(self.cfg.d_in, self.cfg.d_sae, dtype=self.cfg.dtype) + ) + ) + self.W_dec = nn.Parameter( + torch.nn.init.kaiming_uniform_( + torch.empty(self.cfg.d_sae, self.cfg.d_in, dtype=self.cfg.dtype) + ) + ) + self.b_enc = nn.Parameter(torch.zeros(self.cfg.d_sae, dtype=self.cfg.dtype)) + self.b_dec = nn.Parameter(torch.zeros(self.cfg.d_in, dtype=self.cfg.dtype)) + + self.hook_sae_input = HookPoint() + self.hook_sae_acts_pre = HookPoint() + self.hook_sae_acts_post = HookPoint() + self.hook_sae_recons = HookPoint() + self.hook_sae_error = HookPoint() + self.hook_sae_output = HookPoint() + + self.to(self.cfg.device) + self.setup() + + def forward(self, input: Float[torch.Tensor, "... d_in"]) -> Float[torch.Tensor, "... d_in"]: + """SAE Forward Pass. + + Args: + input: The input tensor of activations to the SAE. Shape [..., d_in]. + Also supports hook_z activations of shape [..., n_heads, d_head], where n_heads * d_head = d_in, for attention output (hook_z) SAEs. + + Returns: + output: The reconstructed output tensor from the SAE, with the error term optionally added. Same shape as input (eg [..., d_in]) + """ + self.hook_sae_input(input) + if input.shape[-1] == self.cfg.d_in: + x = input + else: + # Assume this this is an attention output (hook_z) SAE + assert self.cfg.hook_name.endswith( + "_z" + ), f"You passed in an input shape {input.shape} does not match SAE input size {self.cfg.d_in} for hook_name {self.cfg.hook_name}. This is only supported for attn output (hook_z) SAEs." + x = einops.rearrange(input, "... n_heads d_head -> ... 
(n_heads d_head)") + assert ( + x.shape[-1] == self.cfg.d_in + ), f"Input shape {x.shape} does not match SAE input size {self.cfg.d_in}" + + x_cent = x - self.b_dec + # WARNING: if editing this block of code, also edit the error computation inside `if self.cfg.use_error_term` + sae_acts_pre = self.hook_sae_acts_pre( + einops.einsum(x_cent, self.W_enc, "... d_in, d_in d_sae -> ... d_sae") + + self.b_enc # [..., d_sae] + ) + sae_acts_post = self.hook_sae_acts_post(F.relu(sae_acts_pre)) # [..., d_sae] + x_reconstruct = self.hook_sae_recons( + ( + einops.einsum(sae_acts_post, self.W_dec, "... d_sae, d_sae d_in -> ... d_in") + + self.b_dec + ).reshape(input.shape) + ) + # END WARNING + + if self.cfg.use_error_term: + with torch.no_grad(): + # Recompute everything without hooks to get true error term + # Otherwise, the output with error term will always equal input, even for causal interventions that affect x_reconstruct + # This is in a no_grad context to detach the error, so we can compute SAE feature gradients (eg for attribution patching). See A.3 in https://arxiv.org/pdf/2403.19647.pdf for more detail + # NOTE: we can't just use `sae_error = input - x_reconstruct.detach()` or something simpler, since this would mean intervening on features would mean ablating features still results in perfect reconstruction. + sae_acts_pre_clean = ( + einops.einsum(x_cent, self.W_enc, "... d_in, d_in d_sae -> ... d_sae") + + self.b_enc + ) # [..., d_sae] + sae_acts_post_clean = F.relu(sae_acts_pre_clean) + x_reconstruct_clean = ( + einops.einsum( + sae_acts_post_clean, + self.W_dec, + "... d_sae, d_sae d_in -> ... 
d_in", + ) + + self.b_dec + ).reshape(input.shape) + + sae_error = self.hook_sae_error(input - x_reconstruct_clean) + return self.hook_sae_output(x_reconstruct + sae_error) + + return self.hook_sae_output(x_reconstruct) diff --git a/transformer_lens/HookedSAEConfig.py b/transformer_lens/HookedSAEConfig.py new file mode 100644 index 000000000..2892329e4 --- /dev/null +++ b/transformer_lens/HookedSAEConfig.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import pprint +import random +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import numpy as np +import torch + +from transformer_lens import utils + + +@dataclass +class HookedSAEConfig: + """ + Configuration class to store the configuration of a HookedSAE model. + + Args: + d_sae (int): The size of the dictionary. + d_in (int): The dimension of the input activations. + hook_name (str): The hook name of the activation the SAE was trained on (eg. blocks.0.attn.hook_z) + use_error_term (bool): Whether to use the error term in the loss function. Defaults to False. + dtype (torch.dtype, *optional*): The SAE's dtype. Defaults to torch.float32. + seed (int, *optional*): The seed to use for the SAE. + Used to set sources of randomness (Python, PyTorch and + NumPy) and to initialize weights. Defaults to None. We recommend setting a seed, so your experiments are reproducible. + device(str): The device to use for the SAE. Defaults to 'cuda' if + available, else 'cpu'. + """ + + d_sae: int + d_in: int + hook_name: str + use_error_term: bool = False + dtype: torch.dtype = torch.float32 + seed: Optional[int] = None + device: Optional[str] = None + + def __post_init__(self): + if self.seed is not None: + self.set_seed_everywhere(self.seed) + + if self.device is None: + self.device = utils.get_device() + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> HookedSAEConfig: + """ + Instantiates a `HookedSAEConfig` from a Python dictionary of + parameters. 
+ """ + return cls(**config_dict) + + def to_dict(self): + return self.__dict__ + + def __repr__(self): + return "HookedSAEConfig:\n" + pprint.pformat(self.to_dict()) + + def set_seed_everywhere(self, seed: int): + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) diff --git a/transformer_lens/HookedSAETransformer.py b/transformer_lens/HookedSAETransformer.py new file mode 100644 index 000000000..47e88ebb9 --- /dev/null +++ b/transformer_lens/HookedSAETransformer.py @@ -0,0 +1,290 @@ +import logging +from contextlib import contextmanager +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from jaxtyping import Float + +from transformer_lens.ActivationCache import ActivationCache +from transformer_lens.hook_points import HookPoint # Hooking utilities +from transformer_lens.HookedSAE import HookedSAE +from transformer_lens.HookedTransformer import HookedTransformer + +SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor +LossPerToken = Float[torch.Tensor, "batch pos-1"] +Loss = Union[SingleLoss, LossPerToken] + + +def get_deep_attr(obj: Any, path: str): + """Helper function to get a nested attribute from a object. + In practice used to access HookedTransformer HookPoints (eg model.blocks[0].attn.hook_z) + + Args: + obj: Any object. In practice, this is a HookedTransformer (or subclass) + path: str. The path to the attribute you want to access. (eg "blocks.0.attn.hook_z") + + returns: + Any. The attribute at the end of the path + """ + parts = path.split(".") + # Navigate to the last component in the path + for part in parts: + if part.isdigit(): # This is a list index + obj = obj[int(part)] + else: # This is an attribute + obj = getattr(obj, part) + return obj + + +def set_deep_attr(obj: Any, path: str, value: Any): + """Helper function to change the value of a nested attribute from a object. 
+ In practice used to swap HookedTransformer HookPoints (eg model.blocks[0].attn.hook_z) with HookedSAEs and vice versa + + Args: + obj: Any object. In practice, this is a HookedTransformer (or subclass) + path: str. The path to the attribute you want to access. (eg "blocks.0.attn.hook_z") + value: Any. The value you want to set the attribute to (eg a HookedSAE object) + """ + parts = path.split(".") + # Navigate to the last component in the path + for part in parts[:-1]: + if part.isdigit(): # This is a list index + obj = obj[int(part)] + else: # This is an attribute + obj = getattr(obj, part) + # Set the value on the final attribute + setattr(obj, parts[-1], value) + + +class HookedSAETransformer(HookedTransformer): + def __init__( + self, + *model_args, + **model_kwargs, + ): + """Model initialization. Just HookedTransformer init, but adds a dictionary to keep track of attached SAEs. + + Note that if you want to load the model from pretrained weights, you should use + :meth:`from_pretrained` instead. + + Args: + *model_args: Positional arguments for HookedTransformer initialization + **model_kwargs: Keyword arguments for HookedTransformer initialization + """ + super().__init__(*model_args, **model_kwargs) + self.acts_to_saes: Dict[str, HookedSAE] = {} + + def add_sae(self, sae: HookedSAE): + """Attaches an SAE to the model + + WARNING: This sae will be permanently attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point. + + Args: + sae: HookedSAE. The SAE to attach to the model + """ + act_name = sae.cfg.hook_name + if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict): + logging.warning( + f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
+ ) + return + + self.acts_to_saes[act_name] = sae + set_deep_attr(self, act_name, sae) + self.setup() + + def _reset_sae(self, act_name: str, prev_sae: Optional[HookedSAE] = None): + """Resets an SAE that was attached to the model + + By default will remove the SAE from that hook_point. + If prev_sae is provided, will replace the current SAE with the provided one. + This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes) + + Args: + act_name: str. The hook_name of the SAE to reset + prev_sae: Optional[HookedSAE]. The SAE to replace the current one with. If None, will just remove the SAE from this hook point. Defaults to None + """ + if act_name not in self.acts_to_saes: + logging.warning(f"No SAE is attached to {act_name}. There's nothing to reset.") + return + + if prev_sae: + set_deep_attr(self, act_name, prev_sae) + self.acts_to_saes[act_name] = prev_sae + else: + set_deep_attr(self, act_name, HookPoint()) + del self.acts_to_saes[act_name] + + def reset_saes( + self, + act_names: Optional[Union[str, List[str]]] = None, + prev_saes: Optional[List[Union[HookedSAE, None]]] = None, + ): + """Reset the SAEs attached to the model + + If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model. + Optionally can provide a list of prev_saes to reset to. This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes). + + Args: + act_names (Optional[Union[str, List[str]]): The act_names of the SAEs to reset. If None, will reset all SAEs attached to the model. Defaults to None. + prev_saes (Optional[List[Union[HookedSAE, None]]]): List of SAEs to replace the current ones with. If None, will just remove the SAEs. Defaults to None. 
+ """ + if isinstance(act_names, str): + act_names = [act_names] + elif act_names is None: + act_names = list(self.acts_to_saes.keys()) + + if prev_saes: + assert len(act_names) == len( + prev_saes + ), "act_names and prev_saes must have the same length" + else: + prev_saes = [None] * len(act_names) + + for act_name, prev_sae in zip(act_names, prev_saes): + self._reset_sae(act_name, prev_sae) + + self.setup() + + def run_with_saes( + self, + *model_args, + saes: Union[HookedSAE, List[HookedSAE]] = [], + reset_saes_end: bool = True, + **model_kwargs, + ) -> Union[ + None, + Float[torch.Tensor, "batch pos d_vocab"], + Loss, + Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], + ]: + """Wrapper around HookedTransformer forward pass. + + Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after. + + Args: + *model_args: Positional arguments for the model forward pass + saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass + reset_saes_end (bool): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True. + **model_kwargs: Keyword arguments for the model forward pass + """ + with self.saes(saes=saes, reset_saes_end=reset_saes_end): + return self(*model_args, **model_kwargs) + + def run_with_cache_with_saes( + self, + *model_args, + saes: Union[HookedSAE, List[HookedSAE]] = [], + reset_saes_end: bool = True, + return_cache_object=True, + remove_batch_dim=False, + **kwargs, + ) -> Tuple[ + Union[ + None, + Float[torch.Tensor, "batch pos d_vocab"], + Loss, + Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], + ], + Union[ActivationCache, Dict[str, torch.Tensor]], + ]: + """Wrapper around 'run_with_cache' in HookedTransformer. + + Attaches given SAEs before running the model with cache and then removes them. 
+ By default, will reset all SAEs to original state after. + + Args: + *model_args: Positional arguments for the model forward pass + saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass + reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True. + return_cache_object: (bool) if True, this will return an ActivationCache object, with a bunch of + useful HookedTransformer specific methods, otherwise it will return a dictionary of + activations as in HookedRootModule. + remove_batch_dim: (bool) Whether to remove the batch dimension (only works for batch_size==1). Defaults to False. + **kwargs: Keyword arguments for the model forward pass + """ + with self.saes(saes=saes, reset_saes_end=reset_saes_end): + return self.run_with_cache( + *model_args, + return_cache_object=return_cache_object, + remove_batch_dim=remove_batch_dim, + **kwargs, + ) + + def run_with_hooks_with_saes( + self, + *model_args, + saes: Union[HookedSAE, List[HookedSAE]] = [], + reset_saes_end: bool = True, + fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], + bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], + reset_hooks_end=True, + clear_contexts=False, + **model_kwargs, + ): + """Wrapper around 'run_with_hooks' in HookedTransformer. + + Attaches the given SAEs to the model before running the model with hooks and then removes them. + By default, will reset all SAEs to original state after. + + Args: + *model_args: Positional arguments for the model forward pass + saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass + reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. 
(default: True) + fwd_hooks: (List[Tuple[Union[str, Callable], Callable]]) List of forward hooks to apply + bwd_hooks: (List[Tuple[Union[str, Callable], Callable]]) List of backward hooks to apply + reset_hooks_end: (bool) Whether to reset the hooks at the end of the forward pass (default: True) + clear_contexts: (bool) Whether to clear the contexts at the end of the forward pass (default: False) + **model_kwargs: Keyword arguments for the model forward pass + """ + with self.saes(saes=saes, reset_saes_end=reset_saes_end): + return self.run_with_hooks( + *model_args, + fwd_hooks=fwd_hooks, + bwd_hooks=bwd_hooks, + reset_hooks_end=reset_hooks_end, + clear_contexts=clear_contexts, + **model_kwargs, + ) + + @contextmanager + def saes( + self, + saes: Union[HookedSAE, List[HookedSAE]] = [], + reset_saes_end: bool = True, + ): + """ + A context manager for adding temporary SAEs to the model. + See HookedTransformer.hooks for a similar context manager for hooks. + By default will keep track of previously attached SAEs, and restore them when the context manager exits. + + Example: + + .. code-block:: python + + from transformer_lens import HookedSAETransformer, HookedSAE, HookedSAEConfig + + model = HookedSAETransformer.from_pretrained('gpt2-small') + sae_cfg = HookedSAEConfig(...) + sae = HookedSAE(sae_cfg) + with model.saes(saes=[sae]): + spliced_logits = model(text) + + + Args: + saes (Union[HookedSAE, List[HookedSAE]]): SAEs to be attached. + reset_saes_end (bool): If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state. 
+ """ + act_names_to_reset = [] + prev_saes = [] + if isinstance(saes, HookedSAE): + saes = [saes] + try: + for sae in saes: + act_names_to_reset.append(sae.cfg.hook_name) + prev_saes.append(self.acts_to_saes.get(sae.cfg.hook_name, None)) + self.add_sae(sae) + yield self + finally: + if reset_saes_end: + self.reset_saes(act_names_to_reset, prev_saes) diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py index e2fb1484b..9ab2acea7 100644 --- a/transformer_lens/__init__.py +++ b/transformer_lens/__init__.py @@ -10,6 +10,9 @@ from .FactoredMatrix import FactoredMatrix from .ActivationCache import ActivationCache from .HookedTransformer import HookedTransformer +from .HookedSAEConfig import HookedSAEConfig +from .HookedSAE import HookedSAE +from .HookedSAETransformer import HookedSAETransformer from .SVDInterpreter import SVDInterpreter from .HookedEncoder import HookedEncoder from . import head_detector From 6293e86ada4f86f00defe15a03491673ac56b463 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 1 May 2024 00:50:21 +0200 Subject: [PATCH 73/73] reworked CI to publish code coverage report (#559) * reworked CI to publish code coverage report * added coverage report to docs * added support for python 3.12 and removed extra steps on legacy versions of python * moved main check back to python 3.11 * removed coverage flag * moved download command * fixed name * specified file name * removed link --- .github/workflows/checks.yml | 100 ++++++++++++++++++++++++++++----- .github/workflows/gh-pages.yml | 63 --------------------- makefile | 7 ++- 3 files changed, 90 insertions(+), 80 deletions(-) delete mode 100644 .github/workflows/gh-pages.yml diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 09b11e966..819cb6539 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -36,8 +36,8 @@ permissions: contents: write jobs: - checks: - name: Checks + compatibility-checks: + name: Compatibility Checks runs-on: 
ubuntu-latest strategy: matrix: @@ -45,7 +45,6 @@ jobs: - "3.8" - "3.9" - "3.10" - - "3.11" steps: - uses: actions/checkout@v3 - name: Install Poetry @@ -67,20 +66,20 @@ jobs: run: | poetry lock --check poetry install --with dev - - name: Check format - run: make check-format - - name: Unit test + - name: Unit Test run: make unit-test - - name: Docstring test - run: make docstring-test - - name: Type check - run: poetry run mypy . + - name: Acceptance Test + run: make acceptance-test - name: Build check run: poetry build + - name: Upload Coverage Report Artifact + uses: actions/upload-artifact@v3 + with: + name: documentation + path: htmlcov - # Acceptance tests are run in parallel with unit checks. - acceptance-tests: - name: Acceptance Tests + code-checks: + name: Code Checks runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 @@ -103,8 +102,21 @@ jobs: run: | poetry lock --check poetry install --with dev - - name: Acceptance test - run: make acceptance-test + - name: Check format + run: make check-format + - name: Docstring test + run: make docstring-test + - name: Type check + run: poetry run mypy . 
+ - name: Test Suite with Coverage Report + run: make coverage-report-test + - name: Build check + run: poetry build + - name: Upload Coverage Report Artifact + uses: actions/upload-artifact@v3 + with: + name: test-coverage + path: htmlcov notebook-checks: name: Notebook Checks @@ -135,3 +147,61 @@ jobs: - name: Check Notebook Output Consistency # Note: currently only checks notebooks we have specifically setup for this run: make notebook-test + + + build-docs: + # When running on a PR, this just checks we can build the docs without errors + # When running on merge to main, it builds the docs and then another job deploys them + name: ${{ github.event_name == 'pull_request' && 'Check Build Docs' || 'Build Docs' }} + runs-on: ubuntu-latest + needs: code-checks + steps: + - uses: actions/checkout@v4 + - name: Install Poetry + uses: snok/install-poetry@v1 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + cache: "poetry" + - name: Install pandoc + uses: awalsh128/cache-apt-pkgs-action@latest + with: + packages: pandoc + version: 1.0 + - name: Install dependencies + run: poetry install --with docs + - name: Download Test Coverage Artifact + uses: actions/download-artifact@v3 + with: + name: test-coverage + path: docs/source/coverage + - name: Build Docs + run: HF_TOKEN="$HF_TOKEN" poetry run build-docs + env: + HF_TOKEN: "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" + - name: Upload Docs Artifact + uses: actions/upload-artifact@v3 + with: + name: documentation + path: docs/build + + deploy-docs: + name: Deploy Docs + runs-on: ubuntu-latest + # Only run if merging a PR into main + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + needs: build-docs + steps: + - uses: actions/checkout@v4 + - name: Download Docs Artifact + uses: actions/download-artifact@v3 + with: + name: documentation + path: docs/build + - name: Upload to GitHub Pages + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: docs/build + 
clean-exclude: | + *.*.*/ \ No newline at end of file diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml deleted file mode 100644 index 220dfc093..000000000 --- a/.github/workflows/gh-pages.yml +++ /dev/null @@ -1,63 +0,0 @@ -name: Docs -on: - push: - branches: - - main - pull_request: - branches: - - '*' - -permissions: - contents: write - -jobs: - build-docs: - # When running on a PR, this just checks we can build the docs without errors - # When running on merge to main, it builds the docs and then another job deploys them - name: ${{ github.event_name == 'pull_request' && 'Check Build Docs' || 'Build Docs' }} - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Install Poetry - uses: snok/install-poetry@v1 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.11" - cache: "poetry" - - name: Install pandoc - uses: awalsh128/cache-apt-pkgs-action@latest - with: - packages: pandoc - version: 1.0 - - name: Install dependencies - run: poetry install --with docs - - name: Build Docs - run: HF_TOKEN="$HF_TOKEN" poetry run build-docs - env: - HF_TOKEN: "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" - - name: Upload Docs Artifact - uses: actions/upload-artifact@v3 - with: - name: documentation - path: docs/build - - deploy-docs: - name: Deploy Docs - runs-on: ubuntu-latest - # Only run if merging a PR into main - if: github.event_name == 'push' && github.ref == 'refs/heads/main' - needs: build-docs - steps: - - uses: actions/checkout@v4 - - name: Download Docs Artifact - uses: actions/download-artifact@v3 - with: - name: documentation - path: docs/build - - name: Upload to GitHub Pages - uses: JamesIves/github-pages-deploy-action@v4 - with: - folder: docs/build - clean-exclude: | - *.*.*/ diff --git a/makefile b/makefile index 17d583dae..4cc4633dc 100644 --- a/makefile +++ b/makefile @@ -9,10 +9,13 @@ check-format: poetry run black --check . 
unit-test: - poetry run pytest --cov=transformer_lens/ --cov-report=term-missing --cov-branch tests/unit + poetry run pytest tests/unit acceptance-test: - poetry run pytest --cov=transformer_lens/ --cov-report=term-missing --cov-branch tests/acceptance + poetry run pytest tests/acceptance + +coverage-report-test: + poetry run pytest --cov=transformer_lens/ --cov-report=html --cov-branch tests/unit tests/acceptance docstring-test: poetry run pytest transformer_lens/