diff --git a/benchmarks/evals/e2e/main.py b/benchmarks/evals/e2e/main.py index 410b07a..679c158 100644 --- a/benchmarks/evals/e2e/main.py +++ b/benchmarks/evals/e2e/main.py @@ -62,6 +62,11 @@ class EvalConfigs: "quest-256", "quest-512", "quest-1024", + "quest_optimized-64", + "quest_optimized-128", + "quest_optimized-256", + "quest_optimized-512", + "quest_optimized-1024", "raas-64", "raas-128", "raas-256", @@ -182,15 +187,18 @@ def load_model_for_approach(self, model_name: str, approach_name: str) -> AutoMo model_config = self.configs.model_config if model_config.model_type == "llama": - from transformers import LlamaForCausalLM + + optimized = ("optimized" in approach_name) if approach_name == "full" or "sink" in approach_name: # They differ only in cache type + from transformers import LlamaForCausalLM model = LlamaForCausalLM.from_pretrained( model_name, device_map="cuda:0", trust_remote_code=True, ) elif "h2o" in approach_name: + from transformers import LlamaForCausalLM from quest.models.h2o_llama import enable_h2o_attention_eval model = LlamaForCausalLM.from_pretrained( @@ -202,9 +210,25 @@ def load_model_for_approach(self, model_name: str, approach_name: str) -> AutoMo model, {"cache_budget": int(approach_name.split("-")[-1])}, ) - elif "quest" in approach_name: + elif "quest" in approach_name and optimized: + from quest.models.quest_llama_optimized import LlamaForCausalLM + from quest.models.quest_llama_optimized import enable_quest_attention_eval + model = LlamaForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + torch_dtype=torch.float16, # Use float16 for optimized version + ) + enable_quest_attention_eval( + model, + { + "cache_budget": int(approach_name.split("-")[-1]), + "page_size": 16, # Fixed as stated in the paper + }, + ) + elif "quest" in approach_name and not optimized: + from transformers import LlamaForCausalLM from quest.models.quest_llama import enable_quest_attention_eval - model = LlamaForCausalLM.from_pretrained( model_name, device_map="cuda:0", @@ -218,6 +242,7 @@ def load_model_for_approach(self, model_name: str, approach_name: str) -> AutoMo }, ) elif "raas" in approach_name: + from transformers import LlamaForCausalLM from quest.models.raas_llama import enable_raas_attention_eval model = LlamaForCausalLM.from_pretrained( @@ -353,7 +378,6 @@ def test_model( cache_budget = int(self.configs.approach.split("-")[-1]) past_key_values = RaaSCache(page_size=16, cache_budget=cache_budget) - with torch.no_grad(): # Prefill @@ -403,6 +427,8 @@ def test_model( JCT = prefill_time + np.sum(decode_time) TPOT = np.sum(decode_time) / num_decode + if "optimized" in self.configs.approach: + pipe.model.reset_model() model_output = pipe.tokenizer.decode(generated_content, skip_special_tokens=True) return model_output, TTFT, JCT, TPOT, num_decode diff --git a/quest/models/full_llama.py b/quest/models/full_llama.py new file mode 100644 index 0000000..b48100f --- /dev/null +++ b/quest/models/full_llama.py @@ -0,0 +1,1474 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_llama import LlamaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf" +_CONFIG_FOR_DOC = "LlamaConfig" + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LlamaConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " + "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." + ) + kwargs["rope_type"] = "linear" + super().__init__(*args, **kwargs) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " + "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " + "__init__)." + ) + kwargs["rope_type"] = "dynamic" + super().__init__(*args, **kwargs) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
+""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/quest/models/quest_llama_optimized.py b/quest/models/quest_llama_optimized.py index 7cc139b..d608981 100644 --- a/quest/models/quest_llama_optimized.py +++ b/quest/models/quest_llama_optimized.py @@ -1,998 +1,1011 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # Based on HuggingFace Llama Model: models/llama/modeling_llama.py -# transformers==4.31.0 #TODO - -""" PyTorch LLaMA model.""" -import math +# transformers==4.47.0 from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, + BaseModelOutputWithPast, + CausalLMOutputWithPast, ) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import PreTrainedModel -from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, ) +from transformers.models.llama.configuration_llama import LlamaConfig + +import quest.quest_utils as quest_utils +from quest.quest_utils.controller import InferenceController -import quest.utils -from quest.utils import rms_norm_forward -from quest.utils.controller import InferenceController logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf" _CONFIG_FOR_DOC = "LlamaConfig" -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1 - ) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
- """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - return rms_norm_forward(hidden_states, self.weight, self.variance_epsilon) - - -class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype(), - ) + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + return quest_utils.rms_norm_forward(hidden_states, self.weight, self.variance_epsilon) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LlamaConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from + (seq_len, num_key_value_heads, head_dim) to (seqlen, num_attention_heads, head_dim) + """ + seq_len, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :].expand(seq_len, num_key_value_heads, n_rep, head_dim) + return hidden_states.reshape(seq_len, num_key_value_heads * n_rep, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + iController: Optional[InferenceController] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + assert bsz == 1, "QuestAttention only supports batch size 1." 
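+        # (The InferenceController tracks a single sequence in its paged KV cache, +        # hence the batch-size-1 restriction above.) +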
+        ori_dtype = hidden_states.dtype +        query_states = self.q_proj(hidden_states).to(torch.float16) +        key_states = self.k_proj(hidden_states).to(torch.float16) +        value_states = self.v_proj(hidden_states).to(torch.float16) -        return ( -            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), -            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), -        ) +        # Reshape to (seq_len, num_heads, head_dim); the batch dimension is dropped since bsz == 1 +        # Not transposed for Append kv cache NHD layout +        query_states = query_states.view(q_len, self.num_heads, self.head_dim) +        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) +        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) +        # Hack for GQA: we need to repeat the key and value states to match the number of heads -class LlamaMLP(nn.Module): -    def __init__(self, config): -        super().__init__() -        self.pretraining_tp = config.pretraining_tp -        self.hidden_size = config.hidden_size -        self.intermediate_size = config.intermediate_size -        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) -        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) -        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) -        self.act_fn = ACT2FN[config.hidden_act] - -    def forward(self, x): -        if self.pretraining_tp > 1: -            slice = self.intermediate_size // self.pretraining_tp -            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) -            up_proj_slices = self.up_proj.weight.split(slice, dim=0) -            down_proj_slices = self.down_proj.weight.split(slice, dim=1) - -            gate_proj = torch.cat( -                [F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1 -            ) -            up_proj = torch.cat( -                [F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1 -            ) - -            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) -            down_proj = [ -                F.linear(intermediate_states[i], down_proj_slices[i]) -                for i in range(self.pretraining_tp) -            ] -            down_proj = sum(down_proj) -        else: -            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - -        return down_proj - - -class QuestAttention(nn.Module): -    """Multi-headed attention from 'Attention Is All You Need' paper""" - -    def __init__(self, config: LlamaConfig, layer_idx: int): -        super().__init__() -        self.layer_idx = layer_idx -        self.config = config -        self.hidden_size = config.hidden_size -        self.num_heads = config.num_attention_heads -        self.head_dim = self.hidden_size // self.num_heads -        self.num_key_value_heads = config.num_key_value_heads -        self.num_key_value_groups = self.num_heads // self.num_key_value_heads -        self.pretraining_tp = config.pretraining_tp -        self.max_position_embeddings = config.max_position_embeddings - -        if (self.head_dim * self.num_heads) != self.hidden_size: -            raise ValueError( -                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" -                f" and `num_heads`: {self.num_heads})." -            ) -        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) -        self.k_proj = nn.Linear( -            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False -        ) -        self.v_proj = nn.Linear( -            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False -        ) -        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) -        self._init_rope() - -    def _init_rope(self): -        # rope_theta is default to 1e4, as set in RoPE kernel API.
-        if self.config.rope_scaling is None: -            self.rotary_emb = LlamaRotaryEmbedding( -                self.head_dim, max_position_embeddings=self.max_position_embeddings -            ) -            self.rope_scale = 1.0 -        else: -            scaling_type = self.config.rope_scaling["type"] -            if scaling_type == "linear": -                # support for Longchat-v1.5. -                self.rope_scale = self.config.rope_scaling["factor"] -            else: -                raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - -    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): -        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - -    def forward( -        self, -        hidden_states: torch.Tensor, -        attention_mask: Optional[torch.Tensor] = None, -        position_ids: Optional[torch.LongTensor] = None, -        past_key_value: Optional[Tuple[torch.Tensor]] = None, -        output_attentions: bool = False, -        use_cache: bool = False, -        iController: Optional[quest.utils.InferenceController] = None, -    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: -        bsz, q_len, _ = hidden_states.size() - -        assert bsz == 1, "QuestAttention only supports batch size 1." -        assert hasattr(self, "layer_idx"), "QuestAttention requires layer_idx to inference." - -        if self.pretraining_tp > 1: -            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp -            query_slices = self.q_proj.weight.split( -                (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 -            ) -            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) -            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - -            query_states = [ -                F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) -            ] -            query_states = torch.cat(query_states, dim=-1) - -            key_states = [ -                F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) -            ] -            key_states = torch.cat(key_states, dim=-1) - -            value_states = [ -                F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) -            ] -            value_states = torch.cat(value_states, dim=-1) - -        else: -            torch.cuda.nvtx.range_push("qkv_proj") -            query_states = self.q_proj(hidden_states) -            key_states = self.k_proj(hidden_states) -            value_states = self.v_proj(hidden_states) -            torch.cuda.nvtx.range_pop() - -        # Not transposed for Append kv cache NHD layout -        query_states = query_states.view(q_len, self.num_heads, self.head_dim) -        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) -        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) +        key_states = repeat_kv(key_states, self.num_key_value_groups) +        value_states = repeat_kv(value_states, self.num_key_value_groups) -        torch.cuda.nvtx.range_push("RoPE") -        quest.utils.apply_rope_in_place( +        quest_utils.apply_rope_in_place( query_states, key_states, iController.kv_cache.seqlen - q_len, -            rope_scale=self.rope_scale, +            rope_scale=self.rope_scaling, +            rope_theta=self.rope_theta, ) -        torch.cuda.nvtx.range_pop() -        torch.cuda.nvtx.range_push("append_kv") +        # Quest manages the KV cache internally (with paged attention) # Here we do not concat / stack # We concat after RoPE -        quest.utils.append_kv( +        quest_utils.append_kv( key_states, value_states, iController, self.layer_idx, ) -        torch.cuda.nvtx.range_pop() - -        # Prefill/Decode kernels is different -        if q_len > 1: -            torch.cuda.nvtx.range_push("prefill_attn") -            attn_output = quest.utils.prefill_forward( -                query_states, -                iController, -                self.layer_idx, -            ) -            torch.cuda.nvtx.range_pop() -        else: -            # Skipping layers is controled by PAGE_BUDGET, which is set in LlamaModel.
-            if iController.need_estimate() == False: -                torch.cuda.nvtx.range_push("full_attn") -                attn_output = quest.utils.decode_sparse_attn( -                    query_states, -                    iController, -                    self.layer_idx, -                    iController.kv_indices_without_last, -                ) -                torch.cuda.nvtx.range_pop() -            else: -                torch.cuda.nvtx.range_push("estimate") -                estimated_attn_score = quest.utils.decode_estimate( -                    query_states, -                    iController, -                    self.layer_idx, -                ) -                torch.cuda.nvtx.range_pop() - -                torch.cuda.nvtx.range_push("topk") -                quest.utils.decode_topk( -                    estimated_attn_score, -                    iController, -                ) -                torch.cuda.nvtx.range_pop() - -                torch.cuda.nvtx.range_push("approx_attn") -                attn_output = quest.utils.decode_sparse_attn( -                    query_states, -                    iController, -                    self.layer_idx, -                    iController.topk_dindices_buffer, -                ) -                torch.cuda.nvtx.range_pop() - -        attn_output = attn_output.unsqueeze(0)  # unsqueeze the batch dimension + +        if q_len > 1: +            attn_output = quest_utils.prefill_forward( +                query_states, +                iController, +                self.layer_idx, +            ) +        else: +            # Skipping layers is controlled by PAGE_BUDGET, which is set in LlamaModel. +            if not iController.need_estimate(): +                attn_output = quest_utils.decode_sparse_attn( +                    query_states, +                    iController, +                    self.layer_idx, +                    iController.kv_indices_without_last, +                ) +            else: +                estimated_attn_score = quest_utils.decode_estimate( +                    query_states, +                    iController, +                    self.layer_idx, +                ) + +                quest_utils.decode_topk( +                    estimated_attn_score, +                    iController, +                ) + +                attn_output = quest_utils.decode_sparse_attn( +                    query_states, +                    iController, +                    self.layer_idx, +                    iController.topk_dindices_buffer, +                ) + +        attn_output = attn_output.unsqueeze(0)  # unsqueeze the batch dimension +        # FlashInfer output is naturally NHD         # Note that we manually control NHD. Should be more general -        if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): -            raise ValueError( -                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" -                f" {attn_output.size()}" -            ) -        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - -        torch.cuda.nvtx.range_push("o_proj") -        if self.pretraining_tp > 1: -            attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) -            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) -            attn_output = sum( -                [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)] -            ) -        else: -            attn_output = self.o_proj(attn_output) -        torch.cuda.nvtx.range_pop() - -        if not output_attentions: -            attn_weights = None - -        return attn_output, attn_weights, past_key_value +        if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): +            raise ValueError( +                f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" +                f" {attn_output.size()}" +            ) +        attn_output = attn_output.reshape(bsz, q_len, -1).to(ori_dtype) -class LlamaDecoderLayer(nn.Module): -    def __init__(self, config: LlamaConfig, layer_idx: int): -        super().__init__() -        self.hidden_size = config.hidden_size -        self.self_attn = QuestAttention(config=config, layer_idx=layer_idx) -        self.mlp = LlamaMLP(config) -        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) -        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - -    def forward( -        self, -        hidden_states: torch.Tensor, -        attention_mask: Optional[torch.Tensor] = None, -        position_ids: Optional[torch.LongTensor] = None, -        past_key_value: Optional[Tuple[torch.Tensor]] = None, -        output_attentions: Optional[bool] = False, -        use_cache:
Optional[bool] = False, - iController: Optional[InferenceController] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - torch.cuda.nvtx.range_push("input_norm") - hidden_states = self.input_layernorm(hidden_states) - torch.cuda.nvtx.range_pop() - - torch.cuda.nvtx.range_push("LlamaAttention") - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - iController=iController, - ) - torch.cuda.nvtx.range_pop() - hidden_states = residual + hidden_states + attn_output = self.o_proj(attn_output) - # Fully Connected - residual = hidden_states - torch.cuda.nvtx.range_push("norm") - hidden_states = self.post_attention_layernorm(hidden_states) - torch.cuda.nvtx.range_pop() + if not output_attentions: + attn_weights = None - torch.cuda.nvtx.range_push("mlp") - hidden_states = self.mlp(hidden_states) - torch.cuda.nvtx.range_pop() + return attn_output, attn_weights, past_key_value - hidden_states = residual + hidden_states - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + iController: Optional[InferenceController] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + iController=iController, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states - return outputs + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
""" @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, ) class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): - module.gradient_checkpointing = value + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. 
- - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. 
+ + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. """ @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, ) class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] - Args: - config: LlamaConfig - """ + Args: + config: LlamaConfig + """ - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [LlamaDecoderLayer(config, i) for i in range(config.num_hidden_layers)] - ) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) - self.gradient_checkpointing = False + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Leave Quest controller as uninitialized - self.iController = None - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length - ): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify 
both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - # KV-Cache is managed by iController - # if past_key_values is not None: - # past_key_values_length = past_key_values[0][0].shape[2] - # seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - torch.cuda.nvtx.range_push(f"embed") - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - torch.cuda.nvtx.range_pop() - - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - # Configure Quest Controller - # Prepare indices/indptr for newly appended tokens - assert self.iController is not None, "Please init Quest Controller first." 
- self.iController.prepare_metadata(seq_length) - - # Skip layers by setting infinite budgets - if self._quest_skip_layer > 0: - self.iController.set_page_budget(self._quest_max_page_limit) - self.iController.begin_forward(seq_length) - - for idx, decoder_layer in enumerate(self.layers): - # Configure regular skipping layers - if idx == self._quest_skip_layer: - self.iController.end_forward() - self.iController.set_page_budget(self._quest_page_budget) - # Avoid the redundant init/copy of metadata - # if previous skip layer does, then skip it again - self.iController.begin_forward(seq_length, updateTensor=(idx == 0)) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - # past_key_value = past_key_values[idx] if past_key_values is not None else None - # KV-Cache Managed by ourselves - past_key_value = None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - torch.cuda.nvtx.range_push(f"layer={idx}") - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - iController=self.iController, - ) - torch.cuda.nvtx.range_pop() - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - # Configure Quest Controller - self.iController.end_forward() - - torch.cuda.nvtx.range_push("lastnorm") - hidden_states = self.norm(hidden_states) - torch.cuda.nvtx.range_pop() - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - self.pretraining_tp = config.pretraining_tp - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self._config = config # saved for quest init - # Initialize weights and apply final processing - self.post_init() - - def quest_init( - self, - page_size: int, - max_seq_len: int, - token_budget: int = 512, - dtype: torch.dtype = torch.float16, - device=torch.device("cuda:0"), - ): - """ - Init function for Quest. Must be called before forwarding. - This function allocates all GPU memory for max_seq_len KV-Cache. - """ - assert self.model.iController is None, "Can't init Quest Controller twice." 
- - config = self._config - self.model._quest_page_size = page_size - self.model._quest_page_budget = token_budget // page_size # default page budget - self.model._quest_max_page_limit = 1024 * 1024 # arbitraty large size - self.model._quest_skip_layer = 2 - - self.model.iController = InferenceController( - num_layers=config.num_hidden_layers, - num_heads=config.num_attention_heads, - head_dim=config.hidden_size // config.num_attention_heads, - page_size=page_size, - page_budget=self.model._quest_page_budget, - max_seq_len=max_seq_len, # Used for allocating KV Pools - dtype=dtype, - device=device, - ) - - print(f"Quest allocates KV-Cache of {max_seq_len} tokens") - print(f"Token budget is set to {token_budget}") - - def reset_model(self): - """ - Assistant function for cleaning states of KV-Cache, - which prepares for a new conversation. - """ - assert self.model.iController is not None, "Must be called after init." - self.model.iController.clean_states() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - torch.cuda.nvtx.range_push("LlamaForCausalLM") - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - torch.cuda.nvtx.range_push("lm_head") - hidden_states = outputs[0] - if self.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split( - self.vocab_size // self.pretraining_tp, dim=0 - ) - logits = [ - F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp) - ] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - - torch.cuda.nvtx.range_pop() - torch.cuda.nvtx.range_pop() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past - ), - ) - return reordered_past + self.iController: Optional[InferenceController] = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value 
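+ +    # Expected call order for the Quest controller (a sketch; it assumes the +    # rewritten LlamaForCausalLM keeps the quest_init/reset_model interface of +    # the version removed above, and the max_seq_len below is a placeholder): +    # +    #     model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.float16) +    #     model.quest_init(page_size=16, max_seq_len=32768, token_budget=512) +    #     ...  # run prefill/decode; forward() asserts iController is initialized +    #     model.reset_model()  # clean paged-KV state before the next prompt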
+ + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + seq_length = inputs_embeds.shape[1] + + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + # Configure Quest Controller + # Prepare indices/indptr for newly appended tokens + assert self.iController is not None, "Please init Quest Controller first." 
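+        # prepare_metadata() registers the `seq_length` newly appended tokens with
+        # the paged KV cache; the begin_forward()/end_forward() calls below bracket
+        # one forward pass under the currently configured page budget.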
+        self.iController.prepare_metadata(seq_length)
+
+        # Run the first layers with an effectively infinite page budget (full attention)
+        if self._quest_skip_layer > 0:
+            self.iController.set_page_budget(self._quest_max_page_limit)
+        self.iController.begin_forward(seq_length)
+
+        for idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+            # Switch from the skipped (full-attention) layers to the budgeted ones
+            if idx == self._quest_skip_layer:
+                self.iController.end_forward()
+                self.iController.set_page_budget(self._quest_page_budget)
+                # Avoid redundant init/copy of metadata: the tensors are only
+                # uploaded once (at layer 0), so later begin_forward calls skip the copy
+                self.iController.begin_forward(seq_length, updateTensor=(idx == 0))
+
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    causal_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                    position_embeddings,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=causal_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                    iController=self.iController,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            # if use_cache:
+            #     next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        self.iController.end_forward()
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        # if return_legacy_cache:
+        #     next_cache = next_cache.to_legacy_cache()
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+        output_attentions: bool,
+    ):
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and 0.0 in attention_mask:
+                return attention_mask
+            return None
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_static_cache = isinstance(past_key_values, StaticCache)
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
+            ):
+                return None
+
+        dtype, device = input_tensor.dtype, input_tensor.device
+        sequence_length = input_tensor.shape[1]
+        if using_static_cache:
+            target_length = past_key_values.get_max_cache_shape()
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type == "cuda"
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    @staticmethod
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`torch.Tensor`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
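+            # ("inverted" means an additive mask: 0.0 on attended positions and
+            # dtype-min on masked ones, matching the mask built in the else branch)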
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
+
+class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = LlamaModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self._config = config  # saved for quest init
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def quest_init(
+        self,
+        page_size: int,
+        max_seq_len: int,
+        token_budget: int = 512,
+        dtype: torch.dtype = torch.float16,
+        device=torch.device("cuda:0"),
+    ):
+        """
+        Init function for Quest. Must be called before the first forward pass.
+        This function allocates all the GPU memory needed for a `max_seq_len` KV-Cache.
+        """
+        assert self.model.iController is None, "Can't init Quest Controller twice."
+
+        config = self._config
+        self.model._quest_page_size = page_size
+        self.model._quest_page_budget = token_budget // page_size  # default page budget
+        self.model._quest_max_page_limit = 1024 * 1024  # arbitrarily large value
+        self.model._quest_skip_layer = 2
+
+        self.model.iController = InferenceController(
+            num_layers=config.num_hidden_layers,
+            num_heads=config.num_attention_heads,
+            head_dim=config.hidden_size // config.num_attention_heads,
+            page_size=page_size,
+            page_budget=self.model._quest_page_budget,
+            max_seq_len=max_seq_len,  # Used for allocating KV Pools
+            dtype=dtype,
+            device=device,
+        )
+
+        print(f"Quest allocates KV-Cache of {max_seq_len} tokens")
+        print(f"Token budget is set to {token_budget}")
+
+    def reset_model(self):
+        """
+        Helper that clears the KV-Cache state, preparing the model for a new conversation.
+        """
+        assert self.model.iController is not None, "Must be called after init."
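+        # clean_states() is expected to drop all cached pages and page metadata so
+        # the next prefill starts from an empty KV-Cache.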
+ self.model.iController.clean_states() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + use_cache = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +def enable_quest_attention_eval(model, args): + cache_budget = args["cache_budget"] + page_size = args["page_size"] + max_seq_len = args.get("max_seq_len", model.config.max_position_embeddings) + dtype = args.get("dtype", torch.float16) + device = args.get("device", torch.device("cuda:0")) + model.quest_init(page_size, max_seq_len, cache_budget, dtype, device) \ No newline at end of file diff --git a/quest/models/quest_llama_optimized_bak.py b/quest/models/quest_llama_optimized_bak.py new file mode 100644 index 0000000..5011cd1 --- /dev/null +++ b/quest/models/quest_llama_optimized_bak.py @@ -0,0 +1,998 @@ +# Based on HuggingFace Llama Model: models/llama/modeling_llama.py +# transformers==4.31.0 #TODO + +""" PyTorch LLaMA model.""" +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +import quest.quest_utils +from quest.quest_utils import rms_norm_forward +from quest.quest_utils.controller import InferenceController + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Make causal mask used for bi-directional self-attention. 
+ """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1 + ) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + return rms_norm_forward(hidden_states, self.weight, self.variance_epsilon) + + +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.pretraining_tp = config.pretraining_tp + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.pretraining_tp > 1: + slice = self.intermediate_size // self.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat( + [F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1 + ) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) + for i in range(self.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +class QuestAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.pretraining_tp = config.pretraining_tp + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+            )
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(
+            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.v_proj = nn.Linear(
+            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+        self._init_rope()
+
+    def _init_rope(self):
+        # rope_theta defaults to 1e4, as set by the RoPE kernel API.
+        if self.config.rope_scaling is None:
+            self.rotary_emb = LlamaRotaryEmbedding(
+                self.head_dim, max_position_embeddings=self.max_position_embeddings
+            )
+            self.rope_scale = 1.0
+        else:
+            scaling_type = self.config.rope_scaling["type"]
+            if scaling_type == "linear":
+                # Linear scaling, for support of LongChat-v1.5.
+                self.rope_scale = self.config.rope_scaling["factor"]
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        iController: Optional[quest.quest_utils.InferenceController] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        assert bsz == 1, "QuestAttention only supports batch size 1."
+        assert hasattr(self, "layer_idx"), "QuestAttention requires layer_idx for inference."
+
+        if self.pretraining_tp > 1:
+            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
+            query_slices = self.q_proj.weight.split(
+                (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
+            )
+            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+            query_states = [
+                F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
+            ]
+            query_states = torch.cat(query_states, dim=-1)
+
+            key_states = [
+                F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
+            ]
+            key_states = torch.cat(key_states, dim=-1)
+
+            value_states = [
+                F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
+            ]
+            value_states = torch.cat(value_states, dim=-1)
+
+        else:
+            torch.cuda.nvtx.range_push("qkv_proj")
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+            torch.cuda.nvtx.range_pop()
+
+        # Not transposed: the append-KV cache expects the NHD layout
+        query_states = query_states.view(q_len, self.num_heads, self.head_dim)
+        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
+
+        torch.cuda.nvtx.range_push("RoPE")
+        quest.quest_utils.apply_rope_in_place(
+            query_states,
+            key_states,
+            iController.kv_cache.seqlen - q_len,
+            rope_scale=self.rope_scale,
+        )
+        torch.cuda.nvtx.range_pop()
+
+        torch.cuda.nvtx.range_push("append_kv")
+        # Quest manages the KV-Cache internally (with paged attention), so keys and
+        # values are not concatenated/stacked here; they are appended after RoPE.
+        quest.quest_utils.append_kv(
+            key_states,
+            value_states,
+            iController,
+            self.layer_idx,
+        )
+        torch.cuda.nvtx.range_pop()
+
+        # Prefill/Decode kernels are different
+        if q_len > 1:
+            torch.cuda.nvtx.range_push("prefill_attn")
+            attn_output = quest.quest_utils.prefill_forward(
+                query_states,
+                iController,
+                self.layer_idx,
+            )
+            torch.cuda.nvtx.range_pop()
+        else:
+            # Skipping layers is controlled by the page budget, which is set in LlamaModel.
+            if not iController.need_estimate():
+                torch.cuda.nvtx.range_push("full_attn")
+                attn_output = quest.quest_utils.decode_sparse_attn(
+                    query_states,
+                    iController,
+                    self.layer_idx,
+                    iController.kv_indices_without_last,
+                )
+                torch.cuda.nvtx.range_pop()
+            else:
+                torch.cuda.nvtx.range_push("estimate")
+                estimated_attn_score = quest.quest_utils.decode_estimate(
+                    query_states,
+                    iController,
+                    self.layer_idx,
+                )
+                torch.cuda.nvtx.range_pop()
+
+                torch.cuda.nvtx.range_push("topk")
+                quest.quest_utils.decode_topk(
+                    estimated_attn_score,
+                    iController,
+                )
+                torch.cuda.nvtx.range_pop()
+
+                torch.cuda.nvtx.range_push("approx_attn")
+                attn_output = quest.quest_utils.decode_sparse_attn(
+                    query_states,
+                    iController,
+                    self.layer_idx,
+                    iController.topk_dindices_buffer,
+                )
+                torch.cuda.nvtx.range_pop()
+
+        attn_output = attn_output.unsqueeze(0)  # unsqueeze the batch dimension
+        # FlashInfer output is naturally NHD.
+        # Note that we manually control the NHD layout; this should be made more general.
+        if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        torch.cuda.nvtx.range_push("o_proj")
+        if self.pretraining_tp > 1:
+            attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
+            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
+            attn_output = sum(
+                [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]
+            )
+        else:
+            attn_output = self.o_proj(attn_output)
+        torch.cuda.nvtx.range_pop()
+
+        # The Quest kernels do not materialize attention weights
+        attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class LlamaDecoderLayer(nn.Module):
+    def __init__(self, config: LlamaConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = QuestAttention(config=config, layer_idx=layer_idx)
+        self.mlp = LlamaMLP(config)
+        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        iController: Optional[InferenceController] = None,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + torch.cuda.nvtx.range_push("input_norm") + hidden_states = self.input_layernorm(hidden_states) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push("LlamaAttention") + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + iController=iController, + ) + torch.cuda.nvtx.range_pop() + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + torch.cuda.nvtx.range_push("norm") + hidden_states = self.post_attention_layernorm(hidden_states) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push("mlp") + hidden_states = self.mlp(hidden_states) + torch.cuda.nvtx.range_pop() + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlamaModel): + module.gradient_checkpointing = value + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + + # Leave Quest controller as uninitialized + self.iController = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # KV-Cache is managed by iController + # if past_key_values is not None: + # past_key_values_length = past_key_values[0][0].shape[2] + # seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else 
inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length,
+                seq_length + past_key_values_length,
+                dtype=torch.long,
+                device=device,
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        torch.cuda.nvtx.range_push("embed")
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        torch.cuda.nvtx.range_pop()
+
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+
+        hidden_states = inputs_embeds
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        # Configure Quest Controller
+        # Prepare indices/indptr for newly appended tokens
+        assert self.iController is not None, "Please init Quest Controller first."
+        self.iController.prepare_metadata(seq_length)
+
+        # Run the first layers with an effectively infinite page budget (full attention)
+        if self._quest_skip_layer > 0:
+            self.iController.set_page_budget(self._quest_max_page_limit)
+        self.iController.begin_forward(seq_length)
+
+        for idx, decoder_layer in enumerate(self.layers):
+            # Switch from the skipped (full-attention) layers to the budgeted ones
+            if idx == self._quest_skip_layer:
+                self.iController.end_forward()
+                self.iController.set_page_budget(self._quest_page_budget)
+                # Avoid redundant init/copy of metadata: the tensors are only
+                # uploaded once (at layer 0), so later begin_forward calls skip the copy
+                self.iController.begin_forward(seq_length, updateTensor=(idx == 0))
+
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            # past_key_value = past_key_values[idx] if past_key_values is not None else None
+            # KV-Cache is managed by Quest itself
+            past_key_value = None
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, output_attentions, None)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(decoder_layer),
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    None,
+                )
+            else:
+                torch.cuda.nvtx.range_push(f"layer={idx}")
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    iController=self.iController,
+                )
+                torch.cuda.nvtx.range_pop()
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        # Configure Quest Controller
+        self.iController.end_forward()
+
+        torch.cuda.nvtx.range_push("lastnorm")
+        hidden_states = self.norm(hidden_states)
+        torch.cuda.nvtx.range_pop()
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                if v is not None
+            )
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class LlamaForCausalLM(LlamaPreTrainedModel):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = LlamaModel(config)
+        self.pretraining_tp = config.pretraining_tp
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self._config = config  # saved for quest init
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def quest_init(
+        self,
+        page_size: int,
+        max_seq_len: int,
+        token_budget: int = 512,
+        dtype: torch.dtype = torch.float16,
+        device=torch.device("cuda:0"),
+    ):
+        """
+        Init function for Quest. Must be called before the first forward pass.
+        This function allocates all the GPU memory needed for a `max_seq_len` KV-Cache.
+        """
+        assert self.model.iController is None, "Can't init Quest Controller twice."
+
+        config = self._config
+        self.model._quest_page_size = page_size
+        self.model._quest_page_budget = token_budget // page_size  # default page budget
+        self.model._quest_max_page_limit = 1024 * 1024  # arbitrarily large value
+        self.model._quest_skip_layer = 2
+
+        self.model.iController = InferenceController(
+            num_layers=config.num_hidden_layers,
+            num_heads=config.num_attention_heads,
+            head_dim=config.hidden_size // config.num_attention_heads,
+            page_size=page_size,
+            page_budget=self.model._quest_page_budget,
+            max_seq_len=max_seq_len,  # Used for allocating KV Pools
+            dtype=dtype,
+            device=device,
+        )
+
+        print(f"Quest allocates KV-Cache of {max_seq_len} tokens")
+        print(f"Token budget is set to {token_budget}")
+
+    def reset_model(self):
+        """
+        Helper that clears the KV-Cache state, preparing the model for a new conversation.
+        """
+        assert self.model.iController is not None, "Must be called after init."
+        self.model.iController.clean_states()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + torch.cuda.nvtx.range_push("LlamaForCausalLM") + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + torch.cuda.nvtx.range_push("lm_head") + hidden_states = outputs[0] + if self.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split( + self.vocab_size // self.pretraining_tp, dim=0 + ) + logits = [ + F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp) + ] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_pop() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": 
past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(
+                    past_state.index_select(0, beam_idx.to(past_state.device))
+                    for past_state in layer_past
+                ),
+            )
+        return reordered_past
diff --git a/quest/ops/CMakeLists.txt b/quest/quest_ops/CMakeLists.txt
similarity index 61%
rename from quest/ops/CMakeLists.txt
rename to quest/quest_ops/CMakeLists.txt
index 90e3fcb..2c158e7 100644
--- a/quest/ops/CMakeLists.txt
+++ b/quest/quest_ops/CMakeLists.txt
@@ -20,7 +20,7 @@ include(rapids-cuda)
 include(rapids-export)
 include(rapids-find)
 
-project(_kernels LANGUAGES CUDA CXX) # Replace with your project's name
+project(_quest_kernels LANGUAGES CUDA CXX) # Replace with your project's name
 
 # ------------- configure raft -----------------#
 rapids_cpm_init()
@@ -37,11 +37,11 @@ find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib")
 add_subdirectory(${CMAKE_SOURCE_DIR}/../../kernels/3rdparty/pybind ${CMAKE_BINARY_DIR}/pybind11)
 
 file(GLOB PYTORCH_SOURCES "csrc/*.cu")
-pybind11_add_module(_kernels MODULE ${PYTORCH_CPP_SOURCES} ${PYTORCH_SOURCES})
-
-target_compile_definitions(_kernels PRIVATE -DBSK_TORCH_CHECK) # Enable Torch Tensor Dimension Check
-target_include_directories(_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/include)
-target_include_directories(_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/3rdparty/flashinfer/include)
-target_include_directories(_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/3rdparty/pybind/include)
-target_compile_options(_kernels PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda --expt-relaxed-constexpr>)
-target_link_libraries(_kernels PRIVATE ${TORCH_LIBRARIES} raft::raft Python::Python pybind11::module ${TORCH_PYTHON_LIBRARY})
\ No newline at end of file
+pybind11_add_module(_quest_kernels MODULE ${PYTORCH_CPP_SOURCES} ${PYTORCH_SOURCES})
+
+target_compile_definitions(_quest_kernels PRIVATE -DBSK_TORCH_CHECK) # Enable Torch Tensor Dimension Check
+target_include_directories(_quest_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/include)
+target_include_directories(_quest_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/3rdparty/flashinfer/include)
+target_include_directories(_quest_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/3rdparty/pybind/include)
+target_compile_options(_quest_kernels PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda --expt-relaxed-constexpr>)
+target_link_libraries(_quest_kernels PRIVATE ${TORCH_LIBRARIES} raft::raft Python::Python pybind11::module ${TORCH_PYTHON_LIBRARY})
\ No newline at end of file
diff --git a/quest/ops/cmake/fetch_rapids.cmake b/quest/quest_ops/cmake/fetch_rapids.cmake
similarity index 100%
rename from quest/ops/cmake/fetch_rapids.cmake
rename to quest/quest_ops/cmake/fetch_rapids.cmake
diff --git a/quest/ops/cmake/get_raft.cmake b/quest/quest_ops/cmake/get_raft.cmake
similarity index 100%
rename from quest/ops/cmake/get_raft.cmake
rename to quest/quest_ops/cmake/get_raft.cmake
diff --git a/quest/ops/csrc/approx_attn.cu b/quest/quest_ops/csrc/approx_attn.cu
similarity index 100%
rename from quest/ops/csrc/approx_attn.cu
rename to quest/quest_ops/csrc/approx_attn.cu
diff --git a/quest/ops/csrc/batch_prefill.cu b/quest/quest_ops/csrc/batch_prefill.cu
similarity index 100%
rename from quest/ops/csrc/batch_prefill.cu
rename to quest/quest_ops/csrc/batch_prefill.cu
diff --git
a/quest/ops/csrc/bsk_ops.cu b/quest/quest_ops/csrc/bsk_ops.cu
similarity index 96%
rename from quest/ops/csrc/bsk_ops.cu
rename to quest/quest_ops/csrc/bsk_ops.cu
index 300885f..85b553f 100644
--- a/quest/ops/csrc/bsk_ops.cu
+++ b/quest/quest_ops/csrc/bsk_ops.cu
@@ -1,7 +1,7 @@
 #include <torch/extension.h>
 #include "bsk_ops.h"
 
-PYBIND11_MODULE(_kernels, m) {
+PYBIND11_MODULE(_quest_kernels, m) {
   m.def("apply_rope_in_place", &apply_rope_in_place, "Apply RoPE on Q/K in place.");
   m.def("rms_norm_forward", &rms_norm_forward, "rms_norm_forward by cutlass");
   m.def("topk_filtering", &topk_filtering, "Top-k filtering operator");
diff --git a/quest/ops/csrc/bsk_ops.h b/quest/quest_ops/csrc/bsk_ops.h
similarity index 100%
rename from quest/ops/csrc/bsk_ops.h
rename to quest/quest_ops/csrc/bsk_ops.h
diff --git a/quest/ops/csrc/estimate.cu b/quest/quest_ops/csrc/estimate.cu
similarity index 100%
rename from quest/ops/csrc/estimate.cu
rename to quest/quest_ops/csrc/estimate.cu
diff --git a/quest/ops/csrc/page.cu b/quest/quest_ops/csrc/page.cu
similarity index 100%
rename from quest/ops/csrc/page.cu
rename to quest/quest_ops/csrc/page.cu
diff --git a/quest/ops/csrc/pytorch_extension_utils.h b/quest/quest_ops/csrc/pytorch_extension_utils.h
similarity index 100%
rename from quest/ops/csrc/pytorch_extension_utils.h
rename to quest/quest_ops/csrc/pytorch_extension_utils.h
diff --git a/quest/ops/csrc/rms_norm.cu b/quest/quest_ops/csrc/rms_norm.cu
similarity index 100%
rename from quest/ops/csrc/rms_norm.cu
rename to quest/quest_ops/csrc/rms_norm.cu
diff --git a/quest/ops/csrc/topk.cu b/quest/quest_ops/csrc/topk.cu
similarity index 100%
rename from quest/ops/csrc/topk.cu
rename to quest/quest_ops/csrc/topk.cu
diff --git a/quest/ops/setup.sh b/quest/quest_ops/setup.sh
similarity index 100%
rename from quest/ops/setup.sh
rename to quest/quest_ops/setup.sh
diff --git a/quest/quest_utils/__init__.py b/quest/quest_utils/__init__.py
new file mode 100644
index 0000000..420e6d6
--- /dev/null
+++ b/quest/quest_utils/__init__.py
@@ -0,0 +1,276 @@
+import torch
+import math
+from typing import Optional
+
+import quest._quest_kernels as _kernels
+from quest.quest_utils.utils import TensorLayout
+from quest.quest_utils.kv_cache import KvCache
+from quest.quest_utils.controller import InferenceController
+from quest.quest_utils.decode_wrapper import BatchDecodeWithPagedKVCacheWrapper
+
+__all__ = [
+    'TensorLayout',
+    'KvCache',
+    'InferenceController',
+    "BatchDecodeWithPagedKVCacheWrapper",
+    "append_kv",
+    "prefill_forward",
+    "decode_estimate",
+    "decode_topk",
+    "decode_sparse_attn",
+    "rms_norm_forward",
+    "apply_rope_in_place",
+]
+
+def apply_rope_in_place(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    past_kv_len: int,
+    rope_scale: Optional[float] = None,
+    rope_theta: Optional[float] = None,
+):
+    """
+    Semantics of `apply_rope_in_place`:
+        Apply RoPE (Rotary Position Embedding) in place
+        to q and k as produced by the QKV GEMM. The layout is naturally NHD.
+
+    Args:
+        q: Shape: `[N, H, D]`.
+        k: Shape: `[N, H, D]`.
+        past_kv_len: Length of the past KV cache; used to compute the rotary position offsets.
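+        rope_scale: Linear RoPE scaling factor; defaults to 1.0 when None.
+        rope_theta: RoPE base frequency; defaults to 1e4 when None.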
+ """ + if rope_scale is None: + rope_scale = 1.0 + if rope_theta is None: + rope_theta = 1e4 + _kernels.apply_rope_in_place( + q, + k, + past_kv_len, + rope_scale, + rope_theta, + ) + +def rms_norm_forward( + input: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> torch.Tensor: + o = torch.empty_like(input, dtype=input.dtype, device=input.device) + f = _kernels.rms_norm_forward + f( + input, + weight, + o, + epsilon, + ) + return o + +def append_kv( + k: torch.Tensor, + v: torch.Tensor, + iController: InferenceController, + layer_idx: int, +): + """ + Semantics of `append_kv`: + Append new generated k/v into kv cache and meta data cache. + Automatically dispatch to Prefill / Decode Kernel + + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + k: Shape: `[B, N, D]`. Key projection (`X @ W_k`). + v: Shape: `[B, N, D]`. Value projection (`X @ W_v`). + iController: InferenceController object, which contains all needed information. + layer_idx: Layer index of the KV cache. + """ + seq_len = k.size(0) + if seq_len > 1: + _kernels.append_kv_cache_prefill( + k, + v, + iController.kv_cache.buf_layer(layer_idx), + iController.kv_indices_with_last, + iController.kv_indptr_for_append, + iController.kv_cache.last_page_len, + iController.kv_last_page_idx, + iController.metadata_cache.buf_layer(layer_idx), + iController.metadata_indices, + iController.metadata_indptr_for_append, + iController.metadata_cache.last_page_len, + iController.metadata_last_page_idx, + iController.layout + ) + else: + _kernels.append_kv_cache_decode( + k, + v, + iController.kv_cache.buf_layer(layer_idx), + iController.kv_indices_with_last, + iController.kv_indptr_for_append, + iController.kv_cache.last_page_len, + iController.kv_last_page_idx, + iController.metadata_cache.buf_layer(layer_idx), + iController.metadata_indices, + iController.metadata_indptr_for_append, + iController.metadata_cache.last_page_len, + iController.metadata_last_page_idx, + iController.layout + ) + +def prefill_forward( + q: torch.Tensor, + iController: InferenceController, + layer_idx: int, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, +) -> torch.Tensor: + """ + Semantics of `prefill_forward`: + New genrated K/Vs are already in the kv cache and meta data cache (well-maintained). + Perform FlashInfer Self-Attention with Casual Attention. + Note that we not have position shift and current version not support Prefill Optimization. + + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + q: Shape: `[B, N, D]`. Key projection (`X @ W_k`). + iController: InferenceController object, which contains all needed information. + layer_idx: Layer index of the KV cache. + """ + if rope_scale is None: + rope_scale = 1.0 + if rope_theta is None: + rope_theta = 1e4 + + f = _kernels.prefill_with_paged_kv_cache + o = f( + q, + iController.kv_cache.buf_layer(layer_idx), + iController.kv_indices_with_last, + iController.kv_cache.last_page_len, + True, # Casual + iController.layout, + False, # FP16 Accumulator for 4090 + rope_scale, + rope_theta, + ) + return o + +def decode_estimate( + q: torch.Tensor, + iController: InferenceController, + layer_idx: int, +) -> torch.Tensor: + """ + Semantics of `decode_estimate`: + When decoding, estimate the attention score for each page. 
+ + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + q: Shape: `[B, N, D]`. Query projection (`X @ W_q`). + iController: InferenceController object, which contains all needed information. + layer_idx: Layer index of the KV cache. + """ + f = _kernels.estimate_attn_score + # (iController.metadata_cache.seqlen - 1) manually excludes the last element, which is the current page. + o = torch.empty((iController.num_heads, iController.metadata_cache.seqlen - 1), dtype=q.dtype, device=q.device) + f( + q, + o, + iController.metadata_cache.buf_layer(layer_idx), + iController.metadata_indices, + iController.metadata_indptr_for_append, + iController.metadata_cache.last_page_len, # One entry delta is considered by kernel-level implementation + iController.metadata_last_page_idx, + iController.layout, + ) + return o + +def decode_topk( + estimated_attn_score: torch.Tensor, + iController: InferenceController, +): + """ + Semantics of `decode_topk`: + Select the top-k pages with the highest estimated attention scores. + + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + estimated_attn_score: Shape: `[N, num_pages - 1]`. Per-page scores produced by `decode_estimate`. + iController: InferenceController object, which contains all needed information. + """ + # excluding the last page + page_budget = iController.inference_page_budget - 1 + f = _kernels.topk_filtering + f( + estimated_attn_score, + iController.kv_indices_without_last, + iController.topk_dout_buffer, + iController.topk_dindices_buffer, + iController.topk_buf, + page_budget, + ) + +def decode_sparse_attn( + q: torch.Tensor, + iController: InferenceController, + layer_idx: int, + topk_indices: torch.Tensor, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, +) -> torch.Tensor: + """ + Semantics of `decode_sparse_attn`: + Execute self-attention only on the selected pages (the top-k output). + + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + q: Shape: `[B, N, D]`. Query projection (`X @ W_q`). + iController: InferenceController object, which contains all needed information. + layer_idx: Layer index of the KV cache. + topk_indices: Shape: `[N, page_budget-1]`. Top-k indices.
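+ rope_scale: Optional RoPE scaling factor, forwarded to the decode handler. + rope_theta: Optional RoPE base frequency, forwarded to the decode handler.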
+ """ + o = torch.empty_like(q, dtype=q.dtype, device=q.device) + iController._decode_handler.forward( + q, + o, + iController.kv_cache.buf_layer(layer_idx), + topk_indices, + iController.kv_indptr_for_approx_decode, + iController.kv_cache.last_page_len, + iController.kv_last_page_idx, + rope_scale, + rope_theta, + ) + return o \ No newline at end of file diff --git a/quest/utils/controller.py b/quest/quest_utils/controller.py similarity index 97% rename from quest/utils/controller.py rename to quest/quest_utils/controller.py index 135b8d2..0e307d7 100644 --- a/quest/utils/controller.py +++ b/quest/quest_utils/controller.py @@ -1,6 +1,6 @@ -from quest.utils.decode_wrapper import BatchDecodeWithPagedKVCacheWrapper -from quest.utils.kv_cache import KvCache -from quest.utils.utils import TensorLayout +from quest.quest_utils.decode_wrapper import BatchDecodeWithPagedKVCacheWrapper +from quest.quest_utils.kv_cache import KvCache +from quest.quest_utils.utils import TensorLayout import torch diff --git a/quest/utils/decode_wrapper.py b/quest/quest_utils/decode_wrapper.py similarity index 96% rename from quest/utils/decode_wrapper.py rename to quest/quest_utils/decode_wrapper.py index 4cec8e7..c39db2e 100644 --- a/quest/utils/decode_wrapper.py +++ b/quest/quest_utils/decode_wrapper.py @@ -1,8 +1,8 @@ import torch from typing import Optional -import quest._kernels as _kernels -from quest.utils.utils import TensorLayout +import quest._quest_kernels as _kernels +from quest.quest_utils.utils import TensorLayout def _check_kv_layout(kv_layout: str): if not hasattr(TensorLayout, kv_layout): diff --git a/quest/utils/kv_cache.py b/quest/quest_utils/kv_cache.py similarity index 98% rename from quest/utils/kv_cache.py rename to quest/quest_utils/kv_cache.py index ed4d939..d9fc835 100644 --- a/quest/utils/kv_cache.py +++ b/quest/quest_utils/kv_cache.py @@ -1,7 +1,7 @@ # This file is modified from Punica Project # Check ref: https://github.com/punica-ai/punica -from quest.utils.utils import TensorLayout +from quest.quest_utils.utils import TensorLayout import torch class KvPool: diff --git a/quest/utils/utils.py b/quest/quest_utils/utils.py similarity index 100% rename from quest/utils/utils.py rename to quest/quest_utils/utils.py diff --git a/quest/tests/test_approx_attention.py b/quest/tests/test_approx_attention.py index ba10ba8..70674c9 100644 --- a/quest/tests/test_approx_attention.py +++ b/quest/tests/test_approx_attention.py @@ -5,7 +5,7 @@ import torch.nn as nn import math -import quest.utils +import quest.quest_utils def assert_close(a, b): rtol, atol = { @@ -136,7 +136,7 @@ def test_approx_attention_correctness(dtype_str, qo_len, kv_len, page_budget): k_prefill = torch.randn(kv_len-1, num_heads, head_dim, dtype=dtype, device=device) v_prefill = torch.randn(kv_len-1, num_heads, head_dim, dtype=dtype, device=device) - testController = quest.utils.InferenceController( + testController = quest.quest_utils.InferenceController( num_layers, num_heads, head_dim, @@ -151,7 +151,7 @@ def test_approx_attention_correctness(dtype_str, qo_len, kv_len, page_budget): testController.prepare_metadata(kv_len-1) testController.begin_forward(kv_len-1) # Construct KV - quest.utils.append_kv(k_prefill, v_prefill, testController, 0) + quest.quest_utils.append_kv(k_prefill, v_prefill, testController, 0) testController.end_forward() k_decode = torch.randn(1, num_heads, head_dim, dtype=dtype, device=device) @@ -165,10 +165,10 @@ def test_approx_attention_correctness(dtype_str, qo_len, kv_len, page_budget): # CUDA Evaluation 
testController.prepare_metadata(qo_len) testController.begin_forward(qo_len) - quest.utils.append_kv(k_decode, v_decode, testController, 0) + quest.quest_utils.append_kv(k_decode, v_decode, testController, 0) if testController.need_estimate() == False: - o_device = quest.utils.decode_sparse_attn( + o_device = quest.quest_utils.decode_sparse_attn( q, testController, 0, @@ -188,7 +188,7 @@ def test_approx_attention_correctness(dtype_str, qo_len, kv_len, page_budget): # estimated_attn_score, # testController, # ) - o_device = quest.utils.decode_sparse_attn( + o_device = quest.quest_utils.decode_sparse_attn( q, testController, 0, diff --git a/quest/tests/test_decode_attention.py b/quest/tests/test_decode_attention.py index c8b3681..bd494ca 100644 --- a/quest/tests/test_decode_attention.py +++ b/quest/tests/test_decode_attention.py @@ -5,7 +5,7 @@ import torch.nn as nn import math -import quest.utils +import quest.quest_utils def assert_close(a, b): rtol, atol = { @@ -69,7 +69,7 @@ def test_decode_attention_correctness(dtype_str, qo_len, kv_len): k_prefill = torch.randn(kv_len-1, num_heads, head_dim, dtype=dtype, device=device) v_prefill = torch.randn(kv_len-1, num_heads, head_dim, dtype=dtype, device=device) - testController = quest.utils.InferenceController( + testController = quest.quest_utils.InferenceController( num_layers, num_heads, head_dim, @@ -84,7 +84,7 @@ def test_decode_attention_correctness(dtype_str, qo_len, kv_len): testController.prepare_metadata(kv_len-1) testController.begin_forward(kv_len-1) # Construct KV - quest.utils.append_kv(k_prefill, v_prefill, testController, 0) + quest.quest_utils.append_kv(k_prefill, v_prefill, testController, 0) testController.end_forward() k_decode = torch.randn(1, num_heads, head_dim, dtype=dtype, device=device) @@ -92,10 +92,10 @@ def test_decode_attention_correctness(dtype_str, qo_len, kv_len): # Real decoding starts testController.prepare_metadata(1) testController.begin_forward(1) - quest.utils.append_kv(k_decode, v_decode, testController, 0) + quest.quest_utils.append_kv(k_decode, v_decode, testController, 0) # No CPU test cases assert testController.need_estimate() == False - o_device = quest.utils.decode_sparse_attn( + o_device = quest.quest_utils.decode_sparse_attn( q, testController, 0, diff --git a/quest/tests/test_estimate.py b/quest/tests/test_estimate.py index 0a95fd0..1f0b29e 100644 --- a/quest/tests/test_estimate.py +++ b/quest/tests/test_estimate.py @@ -5,7 +5,7 @@ import torch.nn as nn import math -import quest.utils +import quest.quest_utils def assert_close(a, b): rtol, atol = { @@ -103,7 +103,7 @@ def test_estimate_correctness(dtype_str, kv_len): k_prefill = torch.randn(kv_len-1, num_heads, head_dim, dtype=dtype, device=device) v_prefill = torch.randn(kv_len-1, num_heads, head_dim, dtype=dtype, device=device) - testController = quest.utils.InferenceController( + testController = quest.quest_utils.InferenceController( num_layers, num_heads, head_dim, @@ -118,7 +118,7 @@ def test_estimate_correctness(dtype_str, kv_len): testController.prepare_metadata(kv_len-1) testController.begin_forward(kv_len-1) # Construct KV - quest.utils.append_kv(k_prefill, v_prefill, testController, 0) + quest.quest_utils.append_kv(k_prefill, v_prefill, testController, 0) testController.end_forward() k_decode = torch.randn(1, num_heads, head_dim, dtype=dtype, device=device) @@ -127,8 +127,8 @@ def test_estimate_correctness(dtype_str, kv_len): # CUDA Evaluation testController.prepare_metadata(qo_len) testController.begin_forward(qo_len) - 
quest.utils.append_kv(k_decode, v_decode, testController, 0) - cuda_estimated_value = quest.utils.decode_estimate( + quest.quest_utils.append_kv(k_decode, v_decode, testController, 0) + cuda_estimated_value = quest.quest_utils.decode_estimate( q, testController, 0, diff --git a/quest/tests/test_prefill_attention.py b/quest/tests/test_prefill_attention.py index 24f6224..5614826 100644 --- a/quest/tests/test_prefill_attention.py +++ b/quest/tests/test_prefill_attention.py @@ -5,7 +5,7 @@ import torch.nn as nn import math -import quest.utils +import quest.quest_utils def assert_close(a, b): rtol, atol = { @@ -65,7 +65,7 @@ def test_prefill_attention_correctness(dtype_str, qo_len, kv_len): k = torch.randn(kv_len, num_heads, head_dim, dtype=dtype, device=device) v = torch.randn(kv_len, num_heads, head_dim, dtype=dtype, device=device) - testController = quest.utils.InferenceController( + testController = quest.quest_utils.InferenceController( num_layers, num_heads, head_dim, @@ -80,8 +80,8 @@ def test_prefill_attention_correctness(dtype_str, qo_len, kv_len): testController.prepare_metadata(kv_len) testController.begin_forward(kv_len) # Construct KV with maintained metadata - quest.utils.append_kv(k, v, testController, 0) - o_device = quest.utils.prefill_forward(q, testController, 0) + quest.quest_utils.append_kv(k, v, testController, 0) + o_device = quest.quest_utils.prefill_forward(q, testController, 0) o_host = _ref_self_attention(q, k, v) assert_close(o_device, o_host) \ No newline at end of file diff --git a/quest/tests/test_rope.py b/quest/tests/test_rope.py index 7cda3e1..707a3e6 100644 --- a/quest/tests/test_rope.py +++ b/quest/tests/test_rope.py @@ -4,7 +4,7 @@ import torch.nn as nn from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb -import quest.utils +import quest.quest_utils def assert_close(a, b): rtol, atol = { @@ -44,7 +44,7 @@ def test_apply_qk_rope(dtype_str, past_kv_len, seq_len): k = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device=device) q_ref, k_ref = _ref_apply_qk_rope(q, k, past_kv_len) - quest.utils.apply_rope_in_place(q, k, past_kv_len) + quest.quest_utils.apply_rope_in_place(q, k, past_kv_len) assert_close(q, q_ref) assert_close(k, k_ref) \ No newline at end of file diff --git a/quest/tests/test_topk.py b/quest/tests/test_topk.py index 54a6f78..4d39718 100644 --- a/quest/tests/test_topk.py +++ b/quest/tests/test_topk.py @@ -5,7 +5,7 @@ import torch.nn as nn import math -import quest.utils +import quest.quest_utils # This file is used for testing topk kernel from libRAFT # We do not seriously compare the topk indices since the random value leads to similar tensor. 
@@ -50,7 +50,7 @@ def test_topk_correctness(dtype_str, kv_len, k_budget): cuda_output_indices = torch.arange(0, k_budget, dtype=torch.int32, device=device).repeat(num_heads, 1) topk_buf = torch.zeros((num_heads, 8192 * 2 * (2+4) // 2 // 48), dtype=dtype, device=device) - quest.utils._kernels.topk_filtering( + quest.quest_utils._kernels.topk_filtering( cuda_input_data, cuda_input_indices, cuda_output_data, diff --git a/quest/utils/__init__.py b/quest/utils/__init__.py index d42e80e..e69de29 100644 --- a/quest/utils/__init__.py +++ b/quest/utils/__init__.py @@ -1,276 +0,0 @@ -# import torch -# import math -# from typing import Optional - -# import quest._kernels as _kernels -# from quest.utils.utils import TensorLayout -# from quest.utils.kv_cache import KvCache -# from quest.utils.controller import InferenceController -# from quest.utils.decode_wrapper import BatchDecodeWithPagedKVCacheWrapper - -# __all__ = [ -# 'TensorLayout', -# 'KvCache', -# 'InferenceController', -# "BatchDecodeWithPagedKVCacheWrapper", -# "append_kv", -# "prefill_forward", -# "decode_estimate", -# "decode_topk", -# "decode_sparse_attn", -# "rms_norm_forward", -# "apply_rope_in_place", -# ] - -# def apply_rope_in_place( -# q: torch.Tensor, -# k: torch.Tensor, -# past_kv_len: int, -# rope_scale: Optional[float] = None, -# rope_theta: Optional[float] = None, -# ): -# """ -# Semantics of `apply_rope_in_place`: -# Apply RoPE (Relative Positional Encoding) in-place. -# On q, k which is generated by GEMM. Layout is naturally NHD. - -# Args: -# q: Shape: `[N, H, D]`. -# k: Shape: `[N, H, D]`. -# past_kv_len: Length of past KV cache. Used to calculate frequency. -# """ -# if rope_scale is None: -# rope_scale = 1.0 -# if rope_theta is None: -# rope_theta = 1e4 -# _kernels.apply_rope_in_place( -# q, -# k, -# past_kv_len, -# rope_scale, -# rope_theta, -# ) - -# def rms_norm_forward( -# input: torch.Tensor, -# weight: torch.Tensor, -# epsilon: float, -# ) -> torch.Tensor: -# o = torch.empty_like(input, dtype=input.dtype, device=input.device) -# f = _kernels.rms_norm_forward -# f( -# input, -# weight, -# o, -# epsilon, -# ) -# return o - -# def append_kv( -# k: torch.Tensor, -# v: torch.Tensor, -# iController: InferenceController, -# layer_idx: int, -# ): -# """ -# Semantics of `append_kv`: -# Append new generated k/v into kv cache and meta data cache. -# Automatically dispatch to Prefill / Decode Kernel - -# Notations for shapes: -# `B`: batch size -# `N`: number of heads -# `D`: head dimension -# `L`: number of layers -# `MAXLEN`: maximum length of the KV cache - -# Args: -# k: Shape: `[B, N, D]`. Key projection (`X @ W_k`). -# v: Shape: `[B, N, D]`. Value projection (`X @ W_v`). -# iController: InferenceController object, which contains all needed information. -# layer_idx: Layer index of the KV cache.
-# """ -# seq_len = k.size(0) -# if seq_len > 1: -# _kernels.append_kv_cache_prefill( -# k, -# v, -# iController.kv_cache.buf_layer(layer_idx), -# iController.kv_indices_with_last, -# iController.kv_indptr_for_append, -# iController.kv_cache.last_page_len, -# iController.kv_last_page_idx, -# iController.metadata_cache.buf_layer(layer_idx), -# iController.metadata_indices, -# iController.metadata_indptr_for_append, -# iController.metadata_cache.last_page_len, -# iController.metadata_last_page_idx, -# iController.layout -# ) -# else: -# _kernels.append_kv_cache_decode( -# k, -# v, -# iController.kv_cache.buf_layer(layer_idx), -# iController.kv_indices_with_last, -# iController.kv_indptr_for_append, -# iController.kv_cache.last_page_len, -# iController.kv_last_page_idx, -# iController.metadata_cache.buf_layer(layer_idx), -# iController.metadata_indices, -# iController.metadata_indptr_for_append, -# iController.metadata_cache.last_page_len, -# iController.metadata_last_page_idx, -# iController.layout -# ) - -# def prefill_forward( -# q: torch.Tensor, -# iController: InferenceController, -# layer_idx: int, -# rope_scale: Optional[float] = None, -# rope_theta: Optional[float] = None, -# ) -> torch.Tensor: -# """ -# Semantics of `prefill_forward`: -# New genrated K/Vs are already in the kv cache and meta data cache (well-maintained). -# Perform FlashInfer Self-Attention with Casual Attention. -# Note that we not have position shift and current version not support Prefill Optimization. - -# Notations for shapes: -# `B`: batch size -# `N`: number of heads -# `D`: head dimension -# `L`: number of layers -# `MAXLEN`: maximum length of the KV cache - -# Args: -# q: Shape: `[B, N, D]`. Key projection (`X @ W_k`). -# iController: InferenceController object, which contains all needed information. -# layer_idx: Layer index of the KV cache. -# """ -# if rope_scale is None: -# rope_scale = 1.0 -# if rope_theta is None: -# rope_theta = 1e4 - -# f = _kernels.prefill_with_paged_kv_cache -# o = f( -# q, -# iController.kv_cache.buf_layer(layer_idx), -# iController.kv_indices_with_last, -# iController.kv_cache.last_page_len, -# True, # Casual -# iController.layout, -# False, # FP16 Accumulator for 4090 -# rope_scale, -# rope_theta, -# ) -# return o - -# def decode_estimate( -# q: torch.Tensor, -# iController: InferenceController, -# layer_idx: int, -# ) -> torch.Tensor: -# """ -# Semantics of `decode_estimate`: -# When decoding, estimate the attention score for each page. - -# Notations for shapes: -# `B`: batch size -# `N`: number of heads -# `D`: head dimension -# `L`: number of layers -# `MAXLEN`: maximum length of the KV cache - -# Args: -# q: Shape: `[B, N, D]`. Key projection (`X @ W_k`). -# iController: InferenceController object, which contains all needed information. -# layer_idx: Layer index of the KV cache. -# """ -# f = _kernels.estimate_attn_score -# # (iController.metadata_cache.seqlen - 1) is manually excluding the last elements, which is the current page. 
-# o = torch.empty((iController.num_heads, iController.metadata_cache.seqlen - 1), dtype=q.dtype, device=q.device) -# f( -# q, -# o, -# iController.metadata_cache.buf_layer(layer_idx), -# iController.metadata_indices, -# iController.metadata_indptr_for_append, -# iController.metadata_cache.last_page_len, # One entry delta is considered by kernel-level implementation -# iController.metadata_last_page_idx, -# iController.layout, -# ) -# return o - -# def decode_topk( -# estimated_attn_score: torch.Tensor, -# iController: InferenceController, -# ): -# """ -# Semantics of `decode_topk`: -# select top-k pages with highest attention score. - -# Notations for shapes: -# `B`: batch size -# `N`: number of heads -# `D`: head dimension -# `L`: number of layers -# `MAXLEN`: maximum length of the KV cache - -# Args: -# q: Shape: `[B, N, D]`. Key projection (`X @ W_k`). -# iController: InferenceController object, which contains all needed information. -# layer_idx: Layer index of the KV cache. -# """ -# # excluding the last page -# page_budet = iController.inference_page_budget - 1 -# f = _kernels.topk_filtering -# f( -# estimated_attn_score, -# iController.kv_indices_without_last, -# iController.topk_dout_buffer, -# iController.topk_dindices_buffer, -# iController.topk_buf, -# page_budet, -# ) - -# def decode_sparse_attn( -# q: torch.Tensor, -# iController: InferenceController, -# layer_idx: int, -# topk_indices: torch.Tensor, -# rope_scale: Optional[float] = None, -# rope_theta: Optional[float] = None, -# ) -> torch.Tensor: -# """ -# Semantics of `decode_sparse_attn`: -# Excute self-attention only on the selected pages (Top-k output) - -# Notations for shapes: -# `B`: batch size -# `N`: number of heads -# `D`: head dimension -# `L`: number of layers -# `MAXLEN`: maximum length of the KV cache - -# Args: -# q: Shape: `[B, N, D]`. Key projection (`X @ W_k`). -# iController: InferenceController object, which contains all needed information. -# layer_idx: Layer index of the KV cache. -# topk_indices: Shape: `[N, page_budget-1]`. Top-k indices. -# """ -# o = torch.empty_like(q, dtype=q.dtype, device=q.device) -# iController._decode_handler.forward( -# q, -# o, -# iController.kv_cache.buf_layer(layer_idx), -# topk_indices, -# iController.kv_indptr_for_approx_decode, -# iController.kv_cache.last_page_len, -# iController.kv_last_page_idx, -# rope_scale, -# rope_theta, -# ) -# return o \ No newline at end of file diff --git a/quest/utils/cache_utils.py.bak b/quest/utils/cache_utils.py.bak deleted file mode 100644 index d6b3686..0000000 --- a/quest/utils/cache_utils.py.bak +++ /dev/null @@ -1,383 +0,0 @@ -import copy -import importlib.metadata -import json -import os -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch - -from transformers.utils import logging -from transformers.utils.deprecation import deprecate_kwarg - - -logger = logging.get_logger(__name__) - - -class Cache(torch.nn.Module): - """ - Base, abstract class for all caches. The actual data structure is specific to each subclass. - """ - - def __init__(self): - super().__init__() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. 
- value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. These are specific to each subclass and allow new types of - cache to be created. - - Return: - A tuple containing the updated key and value states. - """ - raise NotImplementedError("Make sure to implement `update` in a subclass.") - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") - - def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states, if there is any.""" - raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.") - - def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: - """Given the sequence length of the new inputs, returns the usable length of the cache.""" - # Cache without size limit -> all cache is usable - # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache - # length, we will need to evict part of the cache (and thus not all cache is usable) - max_length = self.get_max_length() - previous_seq_length = self.get_seq_length(layer_idx) - if max_length is not None and previous_seq_length + new_seq_length > max_length: - return max_length - new_seq_length - return previous_seq_length - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx] != []: - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - if self.value_cache[layer_idx] != []: - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - def reset_cache(): - """ - Reset the cache to its initial state. - """ - raise NotImplementedError("Make sure to implement `reset_cache` in a subclass.") - - - @property - def seen_tokens(self): - logger.warning_once( - "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " - "model input instead." - ) - if hasattr(self, "_seen_tokens"): - return self._seen_tokens - else: - return None - - -@dataclass -class CacheConfig: - """ - Base class for cache configs - """ - - cache_implementation: None - - @classmethod - def from_dict(cls, config_dict, **kwargs): - """ - Constructs a CacheConfig instance from a dictionary of parameters. - Args: - config_dict (Dict[str, Any]): Dictionary containing configuration parameters. - **kwargs: Additional keyword arguments to override dictionary values. - - Returns: - CacheConfig: Instance of CacheConfig constructed from the dictionary. 
- """ - config = cls(**config_dict) - to_remove = [] - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) - return config - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file - def to_json_file(self, json_file_path: Union[str, os.PathLike]): - """ - Save this instance to a JSON file. - - Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file in which this configuration instance's parameters will be saved. - use_diff (`bool`, *optional*, defaults to `True`): - If set to `True`, only the difference between the config instance and the default - `QuantizationConfig()` is serialized to JSON file. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - config_dict = self.to_dict() - json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - - writer.write(json_string) - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict - def to_dict(self) -> Dict[str, Any]: - """ - Serializes this instance to a Python dictionary. Returns: - `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. - """ - return copy.deepcopy(self.__dict__) - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ - def __iter__(self): - """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" - for attr, value in copy.deepcopy(self.__dict__).items(): - yield attr, value - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - - def to_json_string(self): - """ - Serializes this instance to a JSON formatted string. - Returns: - str: JSON formatted string representing the configuration instance. - """ - return json.dumps(self.__dict__, indent=2) + "\n" - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update - def update(self, **kwargs): - """ - Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, - returning all the unused kwargs. - - Args: - kwargs (`Dict[str, Any]`): - Dictionary of attributes to tentatively update this class. - - Returns: - `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. - """ - to_remove = [] - for key, value in kwargs.items(): - if hasattr(self, key): - setattr(self, key, value) - to_remove.append(key) - - # Remove all the attributes that were updated, without modifying the input dict - unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} - return unused_kwargs - - -class SinkCache(Cache): - """ - A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to - generate beyond the length of its context window, without losing fluency in the conversation. As it discards past - tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. - - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. - - Parameters: - window_length (`int`): - The length of the context window. - num_sink_tokens (`int`): - The number of sink tokens. See the original paper for more information. 
- - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - SinkCache() - ``` - """ - - def __init__(self, window_length: int, num_sink_tokens: int) -> None: - super().__init__() - self.key_cache: List[torch.Tensor] = [] - self.value_cache: List[torch.Tensor] = [] - self.window_length = window_length - self.num_sink_tokens = num_sink_tokens - self.cos_sin_rerotation_cache = {} - self._cos_cache = None - self._sin_cache = None - self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen - - @staticmethod - def _rotate_half(x): - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def _apply_key_rotary_pos_emb( - self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor - ) -> torch.Tensor: - rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) - return rotated_key_states - - def _get_rerotation_cos_sin( - self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - if key_states.shape[-2] not in self.cos_sin_rerotation_cache: - # Upcast to float32 temporarily for better accuracy - cos = cos.to(torch.float32) - sin = sin.to(torch.float32) - - # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence - original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] - shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] - original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] - shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] - rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin - rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin - - self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( - rerotation_cos.to(key_states.dtype).unsqueeze(0), - rerotation_sin.to(key_states.dtype).unsqueeze(0), - ) - return self.cos_sin_rerotation_cache[key_states.shape[-2]] - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length - if len(self.key_cache) <= layer_idx: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states.""" - return self.window_length - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. 
- value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, - `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the - rotation as the tokens are shifted. - - Return: - A tuple containing the updated key and value states. - """ - # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models - # with partially rotated position embeddings, like Phi or Persimmon. - sin = cache_kwargs.get("sin") - cos = cache_kwargs.get("cos") - partial_rotation_size = cache_kwargs.get("partial_rotation_size") - using_rope = cos is not None and sin is not None - - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the sin/cos cache, which holds sin/cos values for all possible positions - if using_rope and layer_idx == 0: - # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove - # after all RoPE models have a llama-like cache utilization. - if cos.dim() == 2: - self._cos_cache = cos - self._sin_cache = sin - else: - if self._cos_cache is None: - self._cos_cache = cos[0, ...] - self._sin_cache = sin[0, ...] - elif self._cos_cache.shape[0] < self.window_length: - self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) - self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) - - # [bsz, num_heads, seq_len, head_dim] - if len(self.key_cache) <= layer_idx: - # Empty cache - self.key_cache.append(key_states) - self.value_cache.append(value_states) - - elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: - # Growing cache - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - else: - # Shifting cache - keys_to_keep = self.key_cache[layer_idx][ - :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : - ] - - # On RoPE models, we need to recompute the Key rotation as the tokens are shifted - if using_rope: - rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( - key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length] - ) - if partial_rotation_size is not None: - keys_to_keep, keys_pass = ( - keys_to_keep[..., :partial_rotation_size], - keys_to_keep[..., partial_rotation_size:], - ) - keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin) - if partial_rotation_size is not None: - keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) - - # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens - sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] - self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) - - sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] - values_to_keep = self.value_cache[layer_idx][ - :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : - ] - self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - - def reset_cache(self): - self.__init__(self.window_length, 
self.num_sink_tokens) \ No newline at end of file
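
Usage note: the new quest/quest_utils/__init__.py above exposes the sparse decode path as four calls per layer: append_kv, then (when needed) decode_estimate and decode_topk, then decode_sparse_attn. The following is a minimal sketch of how these calls compose for one decode step, mirroring the flow in quest/tests/test_approx_attention.py. It is illustrative only: the helper name quest_decode_step is hypothetical, the InferenceController is assumed to be constructed elsewhere (its constructor arguments are elided in the hunks above), and the use of topk_dindices_buffer / kv_indices_without_last as page-index sources is an inference from decode_topk's argument list.

    import torch
    import quest.quest_utils as qutils

    def quest_decode_step(q, k, v, iController, layer_idx):
        # One decode step for one layer. q/k/v are the [1, N, D] projections of
        # the newly generated token; iController is an already-constructed
        # quest.quest_utils.InferenceController (construction elided in this diff).
        qutils.append_kv(k, v, iController, layer_idx)  # dispatches to the decode append kernel
        if iController.need_estimate():
            # Too many pages for the budget: estimate per-page attention scores
            # from the metadata cache, then keep only the top pages.
            scores = qutils.decode_estimate(q, iController, layer_idx)
            qutils.decode_topk(scores, iController)
            # Assumption: decode_topk writes the selected page indices into
            # iController.topk_dindices_buffer (it is passed as an output buffer).
            topk_indices = iController.topk_dindices_buffer
        else:
            # Everything fits within the page budget: attend over all pages
            # except the last one, which the decode kernel always includes.
            topk_indices = iController.kv_indices_without_last
        # Sparse attention over the selected pages plus the current page.
        return qutils.decode_sparse_attn(q, iController, layer_idx, topk_indices)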