diff --git a/README.md b/README.md index 07d3113..2e8f772 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,10 @@ make -j 4. Build end-to-end operators with PyBind ``` # This will automatically build and link the operators -cd quest/ops +cd quest/quest-ops +bash setup.sh +cd - +cd quest/raas-ops bash setup.sh ``` diff --git a/benchmarks/evals/e2e/main.py b/benchmarks/evals/e2e/main.py index ae2ce8a..86960e6 100644 --- a/benchmarks/evals/e2e/main.py +++ b/benchmarks/evals/e2e/main.py @@ -49,12 +49,18 @@ class EvalConfigs: all_approaches: List[str] = field( default_factory=lambda: [ "full", + "full_optimized", "sink-64", "sink-128", "sink-256", "sink-512", "sink-1024", - "h2o-64", + "sink_optimized-64", + "sink_optimized-128", + "sink_optimized-256", + "sink_optimized-512", + "sink_optimized-1024", + "h2o-84", "h2o-128", "h2o-256", "h2o-512", @@ -64,11 +70,21 @@ class EvalConfigs: "quest-256", "quest-512", "quest-1024", + "quest_optimized-64", + "quest_optimized-128", + "quest_optimized-256", + "quest_optimized-512", + "quest_optimized-1024", "raas-64", "raas-128", "raas-256", "raas-512", "raas-1024", + "raas_optimized-64", + "raas_optimized-128", + "raas_optimized-256", + "raas_optimized-512", + "raas_optimized-1024", ] ) @@ -184,15 +200,26 @@ def load_model_for_approach(self, model_name: str, approach_name: str) -> AutoMo model_config = self.configs.model_config if model_config.model_type == "llama": - from transformers import LlamaForCausalLM - if approach_name == "full" or "sink" in approach_name: # They differ only in cache type + optimized = ("optimized" in approach_name) + + if ("full" in approach_name or "sink" in approach_name) and not optimized: # They differ only in cache type + from transformers import LlamaForCausalLM model = LlamaForCausalLM.from_pretrained( model_name, device_map="cuda:0", trust_remote_code=True, ) + elif ("full" in approach_name or "sink" in approach_name) and optimized: + from quest.models.full_llama_optimized import LlamaForCausalLM + model = LlamaForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + torch_dtype=torch.float16, # Use float16 for optimized version + ) elif "h2o" in approach_name: + from transformers import LlamaForCausalLM from quest.models.h2o_llama import enable_h2o_attention_eval model = LlamaForCausalLM.from_pretrained( @@ -204,9 +231,25 @@ def load_model_for_approach(self, model_name: str, approach_name: str) -> AutoMo model, {"cache_budget": int(approach_name.split("-")[-1])}, ) - elif "quest" in approach_name: + elif "quest" in approach_name and optimized: + from quest.models.quest_llama_optimized import LlamaForCausalLM + from quest.models.quest_llama_optimized import enable_quest_attention_eval + model = LlamaForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + torch_dtype=torch.float16, # Use float16 for optimized version + ) + enable_quest_attention_eval( + model, + { + "cache_budget": int(approach_name.split("-")[-1]), + "page_size": 16, # Fixed as stated in the paper + }, + ) + elif "quest" in approach_name and not optimized: + from transformers import LlamaForCausalLM from quest.models.quest_llama import enable_quest_attention_eval - model = LlamaForCausalLM.from_pretrained( model_name, device_map="cuda:0", @@ -219,7 +262,24 @@ def load_model_for_approach(self, model_name: str, approach_name: str) -> AutoMo "page_size": 16, # Fixed as stated in the paper }, ) - elif "raas" in approach_name: + elif "raas" in approach_name and optimized: + from 
quest.models.raas_llama_optimized import LlamaForCausalLM + from quest.models.raas_llama_optimized import enable_raas_attention_eval + model = LlamaForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + torch_dtype=torch.float16, # Use float16 for optimized version + ) + enable_raas_attention_eval( + model, + { + "cache_budget": int(approach_name.split("-")[-1]), + "page_size": 16, # Fixed as stated in the paper + }, + ) + elif "raas" in approach_name and not optimized: + from transformers import LlamaForCausalLM from quest.models.raas_llama import enable_raas_attention_eval model = LlamaForCausalLM.from_pretrained( @@ -337,7 +397,7 @@ def test_model( cache_position = torch.arange(input_ids.shape[1], dtype=torch.int64, device="cuda:0") # Initialize the cache - if self.configs.approach == "full": + if self.configs.approach in ["full", "full_optimized"]: past_key_values = DynamicCache() elif "sink" in self.configs.approach: cache_budget = int(self.configs.approach.split("-")[-1]) @@ -355,7 +415,6 @@ def test_model( cache_budget = int(self.configs.approach.split("-")[-1]) past_key_values = RaaSCache(page_size=16, cache_budget=cache_budget) - with torch.no_grad(): # Prefill @@ -407,6 +466,8 @@ def test_model( JCT = prefill_time + np.sum(decode_time) TPOT = np.sum(decode_time) / num_decode + if "optimized" in self.configs.approach: + pipe.model.reset_model() model_output = pipe.tokenizer.decode(generated_content, skip_special_tokens=True) return model_output, TTFT, JCT, TPOT, num_decode diff --git a/benchmarks/evals/to_decodenum/__init__.py b/benchmarks/evals/to_decodenum/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/evals/to_decodenum/main.py b/benchmarks/evals/to_decodenum/main.py new file mode 100644 index 0000000..3f31435 --- /dev/null +++ b/benchmarks/evals/to_decodenum/main.py @@ -0,0 +1,568 @@ +import argparse +import logging +import os +import sys +import time +from abc import abstractmethod +from collections import defaultdict +from dataclasses import dataclass, field +from typing import List, Tuple, Dict + +import numpy as np +import torch +from tqdm.contrib import tenumerate +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + DynamicCache, + Pipeline, + SinkCache, + pipeline, +) + +from benchmarks.data_sets.data_set import Data_set +from benchmarks.evals.utils import str2class + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@dataclass +class EvalConfigs: + dataset: str + model: str + approach: str + tot_num_data: int = 5 # used as repeat + all_datasets: List[str] = field( + default_factory=lambda: ["math500", "aime", "gsm8k"] + ) # Fixed mutable default + all_models: List[str] = field( + default_factory=lambda: [ + "peiyi9979/mistral-7b-sft", + "AIDC-AI/Marco-o1", + "Qwen/Qwen2.5-Math-7B-Instruct", + ] + ) + all_approaches: List[str] = field( + default_factory=lambda: [ + "full", + "full_optimized", + "sink-64", + "sink-128", + "sink-256", + "sink-512", + "sink-1024", + "sink_optimized-64", + "sink_optimized-128", + "sink_optimized-256", + "sink_optimized-512", + "sink_optimized-1024", + "h2o-84", + "h2o-128", + "h2o-256", + "h2o-512", + "h2o-1024", + "quest-64", + "quest-128", + "quest-256", + "quest-512", + "quest-1024", + "quest_optimized-64", + "quest_optimized-128", + "quest_optimized-256", + "quest_optimized-512", + "quest_optimized-1024", + "raas-64", + "raas-128", + "raas-256", + "raas-512", + "raas-1024", + "raas_optimized-64", + 
"raas_optimized-128", + "raas_optimized-256", + "raas_optimized-512", + "raas_optimized-1024", + ] + ) + all_decode_nums = [64, 128, 256, 512, 1024, 2048, 3072, 4096, 5120, 6144, 7168, 8192] + + batch_size: int = 1 + model_config: AutoConfig = field(init=False) + seed: int = 42 + + result_path: str = "results" + + @classmethod + def get_configs_from_cli_args(cls) -> "EvalConfigs": + """ + Parse the command line arguments and return the Configs object. + """ + # Add the arguments to the parser. + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str, required=True) + parser.add_argument("--model", type=str, required=True) + parser.add_argument("--approach", type=str, required=True) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seed", type=int, default=42) + + # Parse the arguments. + args = parser.parse_args() + configs = cls(**vars(args)) + return configs + + def __post_init__(self): + """ + Verify the init arguments and create the result path. + """ + self._verify_init_args() + self.result_path = os.path.join(self.result_path, self.dataset, self.model.split("/")[-1]) + os.makedirs(self.result_path, exist_ok=True) + + self.model_config = AutoConfig.from_pretrained(self.model) + + def _verify_init_args(self): + assert self.model in self.all_models, f"{self.model} not in {self.all_models}" + assert self.dataset in self.all_datasets, f"{self.dataset} not in {self.all_datasets}" + assert self.approach in self.all_approaches, f"{self.approach} not in {self.all_approaches}" + + +class EvalEngine: + """ + Evaluate a specific approach on a specific model and a specific dataset. + """ + + def __init__(self, configs: EvalConfigs) -> None: + self.configs = configs + + def run(self): + logging.info( + ( + f"Evaluate \033[32m{self.configs.approach}\033[0m on" + f" \033[32m{self.configs.model}\033[0m and" + f" \033[32m{self.configs.dataset}\033[0m" + ) + ) + logging.info(f"Save the results to \033[32m{self.configs.result_path}\033[0m") + + # Step 1: Preprocessing, load and modify neccessary components such as + # tokenizer, dataset, model and pipeline. + self.tokenizer: AutoTokenizer = self.load_tokenizer(self.configs.model) + self.dataset: Data_set = self.load_dataset(self.configs.dataset, self.tokenizer) + self.model: AutoModelForCausalLM = self.load_model_for_approach( + self.configs.model, self.configs.approach + ) + self.pipe: Pipeline = self.load_pipeline(self.model, self.tokenizer) + + # Step 2: Run the inference and record results into the dataset + self.dataset = self.run_inference(self.pipe, self.dataset) + + # Step 3: Generate the presentation + self.generate_presentation() + + def load_tokenizer(self, model_name: str) -> AutoTokenizer: + """ + Load the tokenizer for the model. + """ + + logger.info(f"Loading the tokenizer \033[32m{model_name}\033[0m") + + # Avoid tokenization warnings (deadlock) + os.environ["TOKENIZERS_PARALLELISM"] = "true" + + return AutoTokenizer.from_pretrained( + model_name, + model_max_length=sys.maxsize, + padding_side="right", + trust_remote_code=True, + ) + + def load_dataset(self, dataset_name: str, tokenizer: AutoTokenizer) -> Data_set: + """ + Load the dataset, finish preprocessing within the Data_set + class and save the dataset. 
+ """ + logger.info(f"Loading the dataset \033[32m{dataset_name}\033[0m") + dataset: Data_set = str2class[dataset_name]( + tokenizer=tokenizer, + path=self.configs.result_path, + tot_num_data=self.configs.tot_num_data, + ) + dataset.save_dataset(self.configs.result_path) + + return dataset + + def load_model_for_approach(self, model_name: str, approach_name: str) -> AutoModelForCausalLM: + """ + Load the model and decide on the type of KV cache. + """ + + logger.info(f"Loading the model \033[32m{model_name}\033[0m") + + model_config = self.configs.model_config + if model_config.model_type == "llama": + + optimized = ("optimized" in approach_name) + + if ("full" in approach_name or "sink" in approach_name) and not optimized: # They differ only in cache type + from transformers import LlamaForCausalLM + model = LlamaForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + ) + elif ("full" in approach_name or "sink" in approach_name) and optimized: + from quest.models.full_llama_optimized import LlamaForCausalLM + model = LlamaForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + torch_dtype=torch.float16, # Use float16 for optimized version + attn_implementation="flash_attention_2", + ) + elif "h2o" in approach_name: + from transformers import LlamaForCausalLM + from quest.models.h2o_llama import enable_h2o_attention_eval + + model = LlamaForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + ) + enable_h2o_attention_eval( + model, + {"cache_budget": int(approach_name.split("-")[-1])}, + ) + elif "quest" in approach_name and optimized: + from quest.models.quest_llama_optimized import LlamaForCausalLM + from quest.models.quest_llama_optimized import enable_quest_attention_eval + model = LlamaForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + torch_dtype=torch.float16, # Use float16 for optimized version + ) + enable_quest_attention_eval( + model, + { + "cache_budget": int(approach_name.split("-")[-1]), + "page_size": 16, # Fixed as stated in the paper + "max_seq_len": 8192 + 1024, + }, + ) + elif "quest" in approach_name and not optimized: + from transformers import LlamaForCausalLM + from quest.models.quest_llama import enable_quest_attention_eval + model = LlamaForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + ) + enable_quest_attention_eval( + model, + { + "cache_budget": int(approach_name.split("-")[-1]), + "page_size": 16, # Fixed as stated in the paper + }, + ) + elif "raas" in approach_name and optimized: + from quest.models.raas_llama_optimized import LlamaForCausalLM + from quest.models.raas_llama_optimized import enable_raas_attention_eval + model = LlamaForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + torch_dtype=torch.float16, # Use float16 for optimized version + ) + enable_raas_attention_eval( + model, + { + "cache_budget": int(approach_name.split("-")[-1]), + "page_size": 16, # Fixed as stated in the paper + "max_seq_len": 8192 + 1024, + }, + ) + elif "raas" in approach_name and not optimized: + from transformers import LlamaForCausalLM + from quest.models.raas_llama import enable_raas_attention_eval + + model = LlamaForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + ) + enable_raas_attention_eval( + model, + { + "cache_budget": int(approach_name.split("-")[-1]), + "page_size": 16, # Fixed as 
stated in the paper + }, + ) + elif model_config.model_type == "qwen2": + from transformers import Qwen2ForCausalLM + + if approach_name == "full" or "sink" in approach_name: # They differ only in cache type + model = Qwen2ForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + ) + elif "h2o" in approach_name: + from quest.models.h2o_qwen2 import enable_h2o_attention_eval + + model = Qwen2ForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + ) + enable_h2o_attention_eval( + model, + {"cache_budget": int(approach_name.split("-")[-1])}, + ) + elif "quest" in approach_name: + from quest.models.quest_qwen2 import enable_quest_attention_eval + + model = Qwen2ForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + ) + enable_quest_attention_eval( + model, + { + "cache_budget": int(approach_name.split("-")[-1]), + "page_size": 16, # Fixed as stated in the paper + }, + ) + elif "raas" in approach_name: + from quest.models.raas_qwen2 import enable_raas_attention_eval + + model = Qwen2ForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + trust_remote_code=True, + ) + enable_raas_attention_eval( + model, + { + "cache_budget": int(approach_name.split("-")[-1]), + "page_size": 16, # Fixed as stated in the paper + }, + ) + + return model + + def load_pipeline(self, model: AutoModelForCausalLM, tokenizer: AutoTokenizer) -> Pipeline: + """ + Assemble the pipeline with the model and the tokenizer. + """ + logger.info("Use a pipeline to aggregate the model and the tokenizer") + return pipeline( + "text-generation", + model=model, + tokenizer=tokenizer, + pad_token_id=tokenizer.eos_token_id, + ) + + def run_inference(self, pipe: Pipeline, dataset: Data_set) -> Data_set: + """ + Run the inference and record the results into the dataset. + """ + + logger.info("Run the inference. This might take a long time... 
Good luck") + results = defaultdict(list) + for i, (prompt, answer) in tenumerate(dataset, desc="dataset", leave=True): + ( + model_output, + TTFT, + JCT, + TPOT, + num_decode, + JCT_decode, + TPOT_decode, + memory_token_decode + ) = self.test_model(pipe, prompt, answer) + # results[f"output_{self.configs.approach}"].append(model_output) + results[f"TTFT_{self.configs.approach}"].append(TTFT) + results[f"JCT_{self.configs.approach}"].append(JCT) + results[f"TPOT_{self.configs.approach}"].append(TPOT) + results[f"num_decode_{self.configs.approach}"].append(num_decode) + results[f"JCT_decode_{self.configs.approach}"].append(JCT_decode) + results[f"TPOT_decode_{self.configs.approach}"].append(TPOT_decode) + results[f"bytes_per_token_{self.configs.approach}"].append(self.get_kv_per_token(pipe.model)) + results[f"memory_token_decode_{self.configs.approach}"].append(memory_token_decode) + dataset.update(results) + dataset.save_dataset(self.configs.result_path) + + return dataset + + def get_kv_per_token(self, model: AutoModelForCausalLM) -> int: + config: AutoConfig = model.config + + hidden_size = config.hidden_size + num_attention_heads = config.num_attention_heads + num_hidden_layers = config.num_hidden_layers + + num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads) + + head_dim = hidden_size // num_attention_heads + + per_layer_token_param = 2 * head_dim * num_key_value_heads + + per_token_param = per_layer_token_param * num_hidden_layers + + bytes_per_token = per_token_param * 2 # for float16 + + return bytes_per_token + + + + + def test_model( + self, pipe: Pipeline, prompt: str, answer: str + ) -> Tuple[str, float, float, float, int, Dict, Dict, Dict]: + + + # test the time + torch.cuda.empty_cache() + # Prepare the input + # try: + # extended_prompt = pipe.tokenizer.apply_chat_template( + # [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True + # ) + # except Exception as e: + # logger.debug(f"No chat template found. 
Using the prompt as is.") + # extended_prompt = prompt + # inputs = pipe.tokenizer(extended_prompt, return_tensors="pt", ).to("cuda:0") + pipe.tokenizer.pad_token = pipe.tokenizer.eos_token + batch_size = self.configs.batch_size + inputs = pipe.tokenizer( + [ + "demo" for _ in range(batch_size) + ], + padding="max_length", + max_length=128, + return_tensors="pt" + ).to("cuda:0") + input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"] + cache_position = torch.arange(input_ids.shape[1], dtype=torch.int64, device="cuda:0") + prompt_length = input_ids.shape[1] + + # Initialize the cache + if self.configs.approach in ["full", "full_optimized"]: + past_key_values = DynamicCache() + elif "sink" in self.configs.approach: + cache_budget = int(self.configs.approach.split("-")[-1]) + past_key_values = SinkCache(window_length=cache_budget, num_sink_tokens=4) + elif "h2o" in self.configs.approach: + from quest.utils.cache_utils import H2OCache + + cache_budget = int(self.configs.approach.split("-")[-1]) + past_key_values = H2OCache(cache_budget=cache_budget) + elif "quest" in self.configs.approach: + # Modifications happen on the model loading stage instead of here + past_key_values = DynamicCache() # quest attention will not discard any cache + elif "raas" in self.configs.approach: + from quest.utils.cache_utils import RaaSCache + + cache_budget = int(self.configs.approach.split("-")[-1]) + past_key_values = RaaSCache(page_size=16, cache_budget=cache_budget) + + JCT_decode = defaultdict(float) + TPOT_decode = defaultdict(float) + memory_token_decode = defaultdict(int) + + with torch.no_grad(): + + # Prefill + start_time = time.perf_counter() + output = pipe.model( + input_ids=input_ids, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + use_cache=True, + ) + prefill_time = time.perf_counter() - start_time + + next_token_id = output.logits[:, -1, :].argmax(dim=-1).unsqueeze(1) + # generated_content = [next_token_id.item()] + + # Decode autoregressively + decode_time = [] + max_decode_num = max(EvalConfigs.all_decode_nums) + for num_decode in range(1, max_decode_num + 1): + input_ids = next_token_id + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + cache_position = cache_position[-1:] + 1 + + start_time = time.perf_counter() + outputs = pipe.model( + input_ids=input_ids, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + use_cache=True, + ) + decode_time.append(time.perf_counter() - start_time) + + # Produece the next token + next_token_id = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1) + # generated_content += [next_token_id.item()] + + # # ignore the eos token + # if next_token_id.item() == pipe.tokenizer.eos_token_id: + # break + if num_decode in EvalConfigs.all_decode_nums: + JCT_decode[num_decode] = prefill_time + np.sum(decode_time) + TPOT_decode[num_decode] = np.sum(decode_time) / num_decode + + # calculate the space + approach = self.configs.approach + if "raas" in approach: + memory_token = min(cache_budget, prompt_length + num_decode) + elif "sink" in approach: + memory_token = min(4 + cache_budget, prompt_length + num_decode) + else: # full, quest + memory_token = prompt_length + num_decode + memory_token_decode[num_decode] = memory_token + + + TTFT = prefill_time + JCT = prefill_time + np.sum(decode_time) + TPOT = np.sum(decode_time) / num_decode + + if "optimized" in self.configs.approach: + 
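# assumption: reset_model() clears per-request state (e.g., pre-allocated KV buffers) kept by the optimized models so the next sample starts clean +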
pipe.model.reset_model() + # model_output = pipe.tokenizer.decode(generated_content, skip_special_tokens=True) + return "", TTFT, JCT, TPOT, num_decode, JCT_decode, TPOT_decode, memory_token_decode + + def generate_presentation(self): + """ + Present the results by invoking this function after executing run(). + Separating this function from run() improves efficiency by saving execution time. + The results are saved in self.configs.result_path. + """ + + # self.dataset.calc_accuracy(self.configs.approach) + self.dataset.save_dataset(self.configs.result_path) + + # accuracy_avg = np.mean(self.dataset.data[f"accuracy_{self.configs.approach}"]) + TTFT_avg = np.mean(self.dataset.data[f"TTFT_{self.configs.approach}"]) + JCT_avg = np.mean(self.dataset.data[f"JCT_{self.configs.approach}"]) + TPOT_avg = np.mean(self.dataset.data[f"TPOT_{self.configs.approach}"]) + num_decode_avg = np.mean(self.dataset.data[f"num_decode_{self.configs.approach}"]) + # logger.info(f"Average accuracy of {self.configs.approach}: {accuracy_avg:.3f}") + logger.info(f"Average TTFT of {self.configs.approach}: {TTFT_avg:.2f} s") + logger.info(f"Average JCT of {self.configs.approach}: {JCT_avg:.2f} s") + logger.info(f"Average TPOT of {self.configs.approach}: {TPOT_avg:.2f} s") + logger.info(f"Average num_decode of {self.configs.approach}: {num_decode_avg:.2f}") + + +if __name__ == "__main__": + + configs = EvalConfigs.get_configs_from_cli_args() + eval_engine = EvalEngine(configs) + eval_engine.run() diff --git a/benchmarks/evals/to_decodenum/plot_time_memory_to_decodenum.py b/benchmarks/evals/to_decodenum/plot_time_memory_to_decodenum.py new file mode 100644 index 0000000..b7f59c2 --- /dev/null +++ b/benchmarks/evals/to_decodenum/plot_time_memory_to_decodenum.py @@ -0,0 +1,120 @@ +import os +from typing import List, Tuple + +import numpy as np +import pandas as pd +from matplotlib import pyplot as plt +import logging + +logging.basicConfig(level=logging.INFO) + +logger = logging.getLogger(__name__) + +from benchmarks.evals.utils import ( + dataset_metrics_map, + dataset_names_map, + model_names_map, +) + +dataset = "gsm8k" +model = "peiyi9979/mistral-7b-sft" +all_approaches = [ + # "raas_optimized-64", + # "raas_optimized-128", + # "raas_optimized-256", + "raas_optimized-512", + "raas_optimized-1024", + # "quest_optimized-64", + # "quest_optimized-128", + # "quest_optimized-256", + # "quest_optimized-512", + "quest_optimized-1024", + "full_optimized", +] + + +def get_line_style(approach: str) -> str | tuple[str, tuple[float, tuple[float, float]]]: + idx = all_approaches.index(approach) + return [ + (0, (3.7, 1.6)), + (2.5, (3.7, 1.6)), + (0.2, (1.5, 1.65)), + "--", + ][idx] + # return '-' + + +def get_label(s: str): + return ( + s.replace("_optimized", "") + .replace("raas", "RaaS") + .replace("quest", "Quest") + .replace("full", "Dense") + ) + + +if __name__ == "__main__": + label_fontsize = 12 + fig, axs = plt.subplots(1, 2, figsize=(6, 2.7)) + plt.subplots_adjust(wspace=0.3) + + last_model_name = model.split("/")[-1] + path = f"results/{dataset}/{last_model_name}/data.json" + dataset = pd.read_json(path) + + # draw the time to decode_num + + for approach in all_approaches: + xs = [] + ys = [] + x_key = f"JCT_decode_{approach}" + if x_key not in dataset.columns: + logger.warning(f"Key {x_key} not found in the dataset") + continue + JCT = dataset[x_key] + dp = len(JCT) + # bonus = 2 if approach == "raas_optimized-1024" else 0 # to show the nearly overlapped lines more clearly + for i in JCT[0].keys(): + xs.append(int(i) / 
1024) + ys.append(np.mean([JCT[j][i] for j in range(dp)])) + axs[0].plot( + xs, ys, label=get_label(approach), linestyle=get_line_style(approach) + ) + + axs[0].legend() + axs[0].set_xlabel("# decode tokens / k", fontsize=label_fontsize) + axs[0].set_ylabel("JCT / s", fontsize=label_fontsize) + + # draw the memory to decode_num + + for approach in all_approaches: + xs = [] + ys = [] + bytes_per_token_key = f"bytes_per_token_{approach}" + memory_token_key = f"memory_token_decode_{approach}" + if bytes_per_token_key not in dataset.columns: + logger.warning(f"Key {bytes_per_token_key} not found in the dataset") + continue + if memory_token_key not in dataset.columns: + logger.warning(f"Key {memory_token_key} not found in the dataset") + continue + bytes_per_token = dataset[bytes_per_token_key] + memory_token = dataset[memory_token_key] + bonus = 0.003 if approach == "quest_optimized-1024" else 0 + for i in memory_token[0].keys(): + xs.append(int(i) / 1024) + ys.append( + np.mean([bytes_per_token * memory_token[j][i] for j in range(dp)]) + / 1024**3 + + bonus + ) + axs[1].plot( + xs, ys, label=get_label(approach), linestyle=get_line_style(approach) + ) + + axs[1].set_xlabel("# decode tokens / k", fontsize=label_fontsize) + axs[1].set_ylabel("KV Cache / GB", fontsize=label_fontsize) + + # Save + + plt.savefig("results/fig-eval-time-memory.pdf", format="pdf", bbox_inches="tight") diff --git a/benchmarks/evals/to_decodenum/run_all.sh b/benchmarks/evals/to_decodenum/run_all.sh new file mode 100644 index 0000000..e015588 --- /dev/null +++ b/benchmarks/evals/to_decodenum/run_all.sh @@ -0,0 +1,31 @@ +#!/bin/bash +SessionName="e2e" +all_datasets=("gsm8k" "aime" "math500") +all_models=("peiyi9979/mistral-7b-sft" "Qwen/Qwen2.5-Math-7B-Instruct") + +# Create a new tmux session +tmux new-session -d -s ${SessionName} + +# Create len(all_datasets) * len(all_models) windows in tmux +counter=2 +for dataset in ${all_datasets[@]}; do + for model in ${all_models[@]}; do + # Create a new window for each dataset-model pair + if [ $counter -eq 0 ]; then + tmux rename-window -t ${SessionName}:${counter} "${dataset:0:4}-${model:0:4}" + else + tmux new-window -t ${SessionName}:${counter} -n "${dataset:0:4}-${model:0:4}" + fi + + # Set the GPU ID for the new window + tmux send-keys -t ${SessionName}:${counter} "export CUDA_VISIBLE_DEVICES=${counter}" C-m + + # Run the command in the new window + tmux send-keys -t ${SessionName}:${counter} "bash run_one.sh ${dataset} ${model}" C-m + + counter=$((counter+1)) + done +done + +# Attach to the tmux session +tmux attach-session -t ${SessionName} \ No newline at end of file diff --git a/benchmarks/evals/to_decodenum/run_one.sh b/benchmarks/evals/to_decodenum/run_one.sh new file mode 100644 index 0000000..7e9697e --- /dev/null +++ b/benchmarks/evals/to_decodenum/run_one.sh @@ -0,0 +1,43 @@ + +dataset="gsm8k" +model="peiyi9979/mistral-7b-sft" +# all_approaches=("full" "sink-64" "sink-128" "sink-256" "sink-512" "sink-1024" "quest-64" "quest-128" "quest-256" "quest-512" "quest-1024") +# all_approaches=("raas-64" "raas-128" "raas-256" "raas-512" "raas-1024") +all_approaches=( + # "raas_optimized-64" + # "raas_optimized-128" + # "raas_optimized-256" + # "raas_optimized-512" + # "raas_optimized-1024" + # "quest_optimized-64" + # "quest_optimized-128" + # "quest_optimized-256" + # "quest_optimized-512" + # "quest_optimized-1024" + "full_optimized" + # "sink_optimized-64" + # "sink_optimized-128" + # "sink_optimized-256" + # "sink_optimized-512" + # "sink_optimized-1024" +) + +# Take 
the arguments from the command line +if [ $# -eq 0 ]; then + echo "No arguments provided. Using default values." +elif [ $# -eq 1 ]; then + dataset=$1 +elif [ $# -eq 2 ]; then + dataset=$1 + model=$2 +else + echo "Too many arguments provided. Usage: $0 [dataset] [model]" + exit 1 +fi + + +for approach in ${all_approaches[@]}; do + command="python3 main.py --dataset ${dataset} --model ${model} --approach ${approach} --batch-size 4" + echo "Running command: ${command}" + ${command} +done diff --git a/benchmarks/evals/to_decodenum/test.sh b/benchmarks/evals/to_decodenum/test.sh new file mode 100644 index 0000000..a224ed8 --- /dev/null +++ b/benchmarks/evals/to_decodenum/test.sh @@ -0,0 +1,8 @@ +dataset="gsm8k" +model="peiyi9979/mistral-7b-sft" +approach="raas_optimized-64" + +command="python3 main.py --dataset ${dataset} --model ${model} --approach ${approach}" +echo "Running command: ${command}" +${command} + diff --git a/quest/models/full_llama.py b/quest/models/full_llama.py new file mode 100644 index 0000000..b48100f --- /dev/null +++ b/quest/models/full_llama.py @@ -0,0 +1,1474 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
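+# NOTE: this file appears to be copied verbatim from transformers' modeling_llama.py. +# The relative imports below (e.g. "from ...activations import ACT2FN") only resolve when the module sits inside the transformers package tree; using it from quest/models/ likely requires rewriting them as absolute "transformers.*" imports.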
+import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_llama import LlamaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf" +_CONFIG_FOR_DOC = "LlamaConfig" + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LlamaConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " + "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." + ) + kwargs["rope_type"] = "linear" + super().__init__(*args, **kwargs) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " + "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " + "__init__)." + ) + kwargs["rope_type"] = "dynamic" + super().__init__(*args, **kwargs) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
+""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
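+        # The SDPA fast path below returns `None` (no mask at all) whenever the padding mask can be safely
+        # ignored, letting scaled_dot_product_attention dispatch to its flash/memory-efficient kernels;
+        # otherwise a full 4D float mask is built via `_prepare_4d_causal_attention_mask_with_cache_position`.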
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
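+            # "Inverted form" means an additive float mask: 0.0 at positions that may be attended and a large
+            # negative value (torch.finfo(dtype).min) at masked positions, i.e. the same layout the `else`
+            # branch below builds from a 2D padding mask. For example, with attention_mask = [[1, 1, 1, 0]]
+            # and sequence_length == target_length == 4 (prefill, cache_position = [0, 1, 2, 3]), that branch
+            # yields causal_mask[0, 0] ==
+            #     [[0, m, m, m],
+            #      [0, 0, m, m],
+            #      [0, 0, 0, m],
+            #      [0, 0, 0, m]]    where m = torch.finfo(dtype).min (last column masked by the padding token).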
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/quest/models/full_llama_optimized.py b/quest/models/full_llama_optimized.py new file mode 100644 index 0000000..8c90698 --- /dev/null +++ b/quest/models/full_llama_optimized.py @@ -0,0 +1,1489 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.models.llama.configuration_llama import LlamaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf" +_CONFIG_FOR_DOC = "LlamaConfig" + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LlamaConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " + "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." + ) + kwargs["rope_type"] = "linear" + super().__init__(*args, **kwargs) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, *args, **kwargs): + logger.warning_once( + "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use " + "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " + "__init__)." + ) + kwargs["rope_type"] = "dynamic" + super().__init__(*args, **kwargs) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + # self.num_kv_heads = config.num_key_value_heads # mocked for none-GQA + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + _bsz, q_len, _ = hidden_states.size() + bsz = 1 + hidden_states_first = hidden_states[:1] + query_states = self.q_proj(hidden_states_first) + key_states = self.k_proj(hidden_states_first) + value_states = self.v_proj(hidden_states_first) + + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # change for none-GQA + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from 
computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # key_states = repeat_kv(key_states, self.num_key_value_groups) + # value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + attn_output = torch.cat([attn_output, hidden_states[1:]], dim=0) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
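+        # The flag is therefore True only for flash_attn < 2.1 and is forwarded to `_flash_attention_forward`
+        # as `use_top_left_mask` so the helper can compensate for the older mask alignment.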
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + _bsz, q_len, _ = hidden_states.size() + bsz = 1 + hidden_states_first = hidden_states[:1] + query_states = self.q_proj(hidden_states_first) + value_states = self.v_proj(hidden_states_first) + key_states = self.k_proj(hidden_states_first) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # change for none-GQA + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask[:1], + q_len, + position_ids=position_ids[:1], + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + attn_output = torch.cat([attn_output, hidden_states[1:]], dim=0) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
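+        # Unlike the eager and flash-attention variants above, which run attention only on the first sequence
+        # in the batch (`hidden_states[:1]`) and pass the remaining rows through unchanged, this SDPA path keeps
+        # the standard full-batch behavior.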
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
+""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
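+ # Past this point the helper either returns `None` (typically SDPA with no padding mask and no static
+ # cache, where the kernel can infer causality itself) or materializes an explicit 4D float mask through
+ # `_prepare_4d_causal_attention_mask_with_cache_position` below.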
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
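+ # (Either way the mask handed back has shape (batch_size, 1, sequence_length, target_length), holding 0.0
+ # at positions that may be attended and `torch.finfo(dtype).min` at masked positions; the `else` branch
+ # below builds it from scratch and then folds any 2D padding mask into it.)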
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + def reset_model(self): + # Do nothing + pass + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. 
+ """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/quest/models/quest_llama_optimized.py b/quest/models/quest_llama_optimized.py index 7cc139b..9badd85 100644 --- a/quest/models/quest_llama_optimized.py +++ b/quest/models/quest_llama_optimized.py @@ -1,998 +1,1014 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # Based on HuggingFace Llama Model: models/llama/modeling_llama.py -# transformers==4.31.0 #TODO - -""" PyTorch LLaMA model.""" -import math -from typing import List, Optional, Tuple, Union +# transformers==4.47.0 +from typing import List, Optional, Tuple, Union, Dict import torch -import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, + BaseModelOutputWithPast, + CausalLMOutputWithPast, ) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import PreTrainedModel -from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, ) +from transformers.models.llama.configuration_llama import LlamaConfig + +import quest.quest_utils as quest_utils +from quest.quest_utils.controller import InferenceController -import quest.utils -from quest.utils import rms_norm_forward -from quest.utils.controller import InferenceController logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf" _CONFIG_FOR_DOC = "LlamaConfig" -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1 - ) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
- """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - return rms_norm_forward(hidden_states, self.weight, self.variance_epsilon) - - -class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype(), - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + return quest_utils.rms_norm_forward(hidden_states, self.weight, self.variance_epsilon) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LlamaConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class LlamaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.pretraining_tp = config.pretraining_tp - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - if self.pretraining_tp > 1: - slice = self.intermediate_size // self.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat( - [F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1 - ) - up_proj = torch.cat( - [F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1 - ) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [ - F.linear(intermediate_states[i], down_proj_slices[i]) - for i in range(self.pretraining_tp) - ] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - return down_proj - - -class QuestAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig, layer_idx: int): - super().__init__() - self.layer_idx = layer_idx - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.pretraining_tp = config.pretraining_tp - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self._init_rope() - - def _init_rope(self): - # rope_theta is default to 1e4, as set in RoPE kernel API. - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings - ) - self.rope_scale = 1.0 - else: - scaling_type = self.config.rope_scaling["type"] - if scaling_type == "linear": - # support for Longchat-v1.5. 
- self.rope_scale = self.config.rope_scaling["factor"] - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - iController: Optional[quest.utils.InferenceController] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - assert bsz == 1, "QuestAttention only supports batch size 1." - assert hasattr(self, "layer_idx"), "QuestAttention requires layer_idx to inference." - - if self.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - torch.cuda.nvtx.range_push("qkv_proj") - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - torch.cuda.nvtx.range_pop() - + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from + (seq_len, num_key_value_heads, head_dim) to (seqlen, num_attention_heads, head_dim) + """ + seq_len, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :].expand(seq_len, num_key_value_heads, n_rep, head_dim) + return hidden_states.reshape(seq_len, num_key_value_heads * n_rep, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + iController: Optional[InferenceController] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + _bsz, q_len, _ = hidden_states.size() + bsz = 1 + + # assert bsz == 1, "QuestAttention only supports batch size 1." 
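+ # Only the first row of the (pseudo-)batch is routed through attention: its Q/K/V projections are cast to
+ # float16 for the FlashInfer-based quest kernels, while the remaining rows (`hidden_states[1:]`) bypass
+ # attention entirely and are concatenated back onto the projected output just before this method returns.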
+ hidden_states_first = hidden_states[0] + ori_dtype = hidden_states.dtype + query_states = self.q_proj(hidden_states_first).to(torch.float16) + key_states = self.k_proj(hidden_states_first).to(torch.float16) + value_states = self.v_proj(hidden_states_first).to(torch.float16) + + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used # Not transposed for Append kv cache NHD layout - query_states = query_states.view(q_len, self.num_heads, self.head_dim) - key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) + query_states = query_states.view(q_len, self.num_heads, self.head_dim) + key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) + + # Hack for GQA: we need to repeat the key and value states to match the number of heads - torch.cuda.nvtx.range_push("RoPE") - quest.utils.apply_rope_in_place( + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + quest_utils.apply_rope_in_place( query_states, key_states, iController.kv_cache.seqlen - q_len, - rope_scale=self.rope_scale, + rope_scale=self.rope_scaling, + rope_theta=self.rope_theta, ) - torch.cuda.nvtx.range_pop() - torch.cuda.nvtx.range_push("append_kv") + # Quest manages KV-Cache internal (with PageAttention) # Here we do not concat / stack # We concat after RoPE - quest.utils.append_kv( + quest_utils.append_kv( key_states, value_states, iController, self.layer_idx, ) - torch.cuda.nvtx.range_pop() - - # Prefill/Decode kernels is different - if q_len > 1: - torch.cuda.nvtx.range_push("prefill_attn") - attn_output = quest.utils.prefill_forward( - query_states, - iController, - self.layer_idx, - ) - torch.cuda.nvtx.range_pop() - else: - # Skipping layers is controled by PAGE_BUDGET, which is set in LlamaModel. - if iController.need_estimate() == False: - torch.cuda.nvtx.range_push("full_attn") - attn_output = quest.utils.decode_sparse_attn( - query_states, - iController, - self.layer_idx, - iController.kv_indices_without_last, - ) - torch.cuda.nvtx.range_pop() - else: - torch.cuda.nvtx.range_push("estimate") - estimated_attn_score = quest.utils.decode_estimate( - query_states, - iController, - self.layer_idx, - ) - torch.cuda.nvtx.range_pop() - - torch.cuda.nvtx.range_push("topk") - quest.utils.decode_topk( - estimated_attn_score, - iController, - ) - torch.cuda.nvtx.range_pop() - - torch.cuda.nvtx.range_push("approx_attn") - attn_output = quest.utils.decode_sparse_attn( - query_states, - iController, - self.layer_idx, - iController.topk_dindices_buffer, - ) - torch.cuda.nvtx.range_pop() - - attn_output = attn_output.unsqueeze(0) # unsqueeze the batch dimension + + if q_len > 1: + attn_output = quest_utils.prefill_forward( + query_states, + iController, + self.layer_idx, + ) + else: + # Skipping layers is controled by PAGE_BUDGET, which is set in LlamaModel. 
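+ # Decode path: if the controller reports that no estimation is needed (presumably because the page budget
+ # already covers the resident KV pages, or this layer skips estimation), sparse attention runs directly
+ # over `kv_indices_without_last`; otherwise Quest estimates per-page attention scores, keeps the top-k
+ # pages within the budget, and attends only over the selected pages.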
+ if not iController.need_estimate(): + attn_output = quest_utils.decode_sparse_attn( + query_states, + iController, + self.layer_idx, + iController.kv_indices_without_last, + ) + else: + estimated_attn_score = quest_utils.decode_estimate( + query_states, + iController, + self.layer_idx, + ) + + quest_utils.decode_topk( + estimated_attn_score, + iController, + ) + + attn_output = quest_utils.decode_sparse_attn( + query_states, + iController, + self.layer_idx, + iController.topk_dindices_buffer, + ) + + attn_output = attn_output.unsqueeze(0) # unsqueeze the batch dimension + # FlashInfer output is naturally NHD # Note that we manully control NHD. Should be more general - if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - torch.cuda.nvtx.range_push("o_proj") - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) - attn_output = sum( - [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)] - ) - else: - attn_output = self.o_proj(attn_output) - torch.cuda.nvtx.range_pop() - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + attn_output = attn_output.reshape(bsz, q_len, -1).to(ori_dtype) -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = QuestAttention(config=config, layer_idx=layer_idx) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - iController: Optional[InferenceController] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - torch.cuda.nvtx.range_push("input_norm") - hidden_states = self.input_layernorm(hidden_states) - torch.cuda.nvtx.range_pop() - - torch.cuda.nvtx.range_push("LlamaAttention") - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - iController=iController, - ) - torch.cuda.nvtx.range_pop() - hidden_states = residual + hidden_states + attn_output = self.o_proj(attn_output) + attn_output = torch.cat([attn_output, hidden_states[1:]], dim=0) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value - # Fully Connected - residual = hidden_states - torch.cuda.nvtx.range_push("norm") - hidden_states = self.post_attention_layernorm(hidden_states) - torch.cuda.nvtx.range_pop() - torch.cuda.nvtx.range_push("mlp") - hidden_states = self.mlp(hidden_states) - torch.cuda.nvtx.range_pop() - hidden_states = residual + hidden_states - outputs = (hidden_states,) +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + iController: Optional[InferenceController] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states[:1] = self.input_layernorm(hidden_states[:1]) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + iController=iController, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states[:1] = self.post_attention_layernorm(hidden_states[:1]) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) + if output_attentions: + outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) + if use_cache: + outputs += (present_key_value,) - return outputs + return outputs LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
""" @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, ) class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): - module.gradient_checkpointing = value + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. 
- - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. 
+ + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. """ @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, ) class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] - Args: - config: LlamaConfig - """ + Args: + config: LlamaConfig + """ - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [LlamaDecoderLayer(config, i) for i in range(config.num_hidden_layers)] - ) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) - self.gradient_checkpointing = False + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Leave Quest controller as uninitialized - self.iController = None - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length - ): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify 
both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - # KV-Cache is managed by iController - # if past_key_values is not None: - # past_key_values_length = past_key_values[0][0].shape[2] - # seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - torch.cuda.nvtx.range_push(f"embed") - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - torch.cuda.nvtx.range_pop() - - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - # Configure Quest Controller - # Prepare indices/indptr for newly appended tokens - assert self.iController is not None, "Please init Quest Controller first." 
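The mask preparation being removed above combines a standard additive causal mask with the user-supplied padding mask before handing it to attention. A toy, self-contained version of what `_make_causal_mask` plus `_expand_mask` produce is sketched below; it is illustrative only and omits the `past_key_values_length` offset handling the real helpers perform.

```python
import torch

def toy_causal_plus_padding_mask(padding_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # padding_mask: (bsz, seq_len) with 1 = keep, 0 = padded.
    bsz, seq_len = padding_mask.shape
    neg = torch.finfo(dtype).min

    # Causal part: large negative above the diagonal, 0 on and below it.
    causal = torch.triu(torch.full((seq_len, seq_len), neg, dtype=dtype), diagonal=1)
    causal = causal[None, None].expand(bsz, 1, seq_len, seq_len)

    # Padding part: large negative wherever the key position is padded.
    padding = (1.0 - padding_mask[:, None, None, :].to(dtype)) * neg
    return causal + padding

mask = toy_causal_plus_padding_mask(torch.tensor([[1, 1, 1, 0]]))
print(mask.shape)  # torch.Size([1, 1, 4, 4]); last key column is masked for every query
```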
- self.iController.prepare_metadata(seq_length) - - # Skip layers by setting infinite budgets - if self._quest_skip_layer > 0: - self.iController.set_page_budget(self._quest_max_page_limit) - self.iController.begin_forward(seq_length) - - for idx, decoder_layer in enumerate(self.layers): - # Configure regular skipping layers - if idx == self._quest_skip_layer: - self.iController.end_forward() - self.iController.set_page_budget(self._quest_page_budget) - # Avoid the redundant init/copy of metadata - # if previous skip layer does, then skip it again - self.iController.begin_forward(seq_length, updateTensor=(idx == 0)) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - # past_key_value = past_key_values[idx] if past_key_values is not None else None - # KV-Cache Managed by ourselves - past_key_value = None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - torch.cuda.nvtx.range_push(f"layer={idx}") - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - iController=self.iController, - ) - torch.cuda.nvtx.range_pop() - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - # Configure Quest Controller - self.iController.end_forward() - - torch.cuda.nvtx.range_push("lastnorm") - hidden_states = self.norm(hidden_states) - torch.cuda.nvtx.range_pop() - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - self.pretraining_tp = config.pretraining_tp - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - self._config = config # saved for quest init - # Initialize weights and apply final processing - self.post_init() - - def quest_init( - self, - page_size: int, - max_seq_len: int, - token_budget: int = 512, - dtype: torch.dtype = torch.float16, - device=torch.device("cuda:0"), - ): - """ - Init function for Quest. Must be called before forwarding. - This function allocates all GPU memory for max_seq_len KV-Cache. - """ - assert self.model.iController is None, "Can't init Quest Controller twice." 
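The loop above implements the layer-skipping scheme: the first `_quest_skip_layer` layers run with an effectively unlimited page budget (dense attention), and only the remaining layers are switched to the configured sparse budget. A schematic of that control flow, using a stand-in controller object that is hypothetical and not the real `InferenceController`:

```python
class ToyController:
    """Stand-in for InferenceController, tracking only the page budget."""
    def __init__(self, page_budget: int):
        self.page_budget = page_budget

    def set_page_budget(self, budget: int) -> None:
        self.page_budget = budget


NUM_LAYERS, SKIP_LAYERS = 32, 2
PAGE_BUDGET, MAX_PAGE_LIMIT = 32, 1024 * 1024   # mirrors the values set in quest_init

ctrl = ToyController(MAX_PAGE_LIMIT)            # skipped layers: effectively dense
for layer_idx in range(NUM_LAYERS):
    if layer_idx == SKIP_LAYERS:
        ctrl.set_page_budget(PAGE_BUDGET)       # later layers: sparse, budgeted
    # ... layer layer_idx attends using at most ctrl.page_budget pages ...
```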
- - config = self._config - self.model._quest_page_size = page_size - self.model._quest_page_budget = token_budget // page_size # default page budget - self.model._quest_max_page_limit = 1024 * 1024 # arbitraty large size - self.model._quest_skip_layer = 2 - - self.model.iController = InferenceController( - num_layers=config.num_hidden_layers, - num_heads=config.num_attention_heads, - head_dim=config.hidden_size // config.num_attention_heads, - page_size=page_size, - page_budget=self.model._quest_page_budget, - max_seq_len=max_seq_len, # Used for allocating KV Pools - dtype=dtype, - device=device, - ) - - print(f"Quest allocates KV-Cache of {max_seq_len} tokens") - print(f"Token budget is set to {token_budget}") - - def reset_model(self): - """ - Assistant function for cleaning states of KV-Cache, - which prepares for a new conversation. - """ - assert self.model.iController is not None, "Must be called after init." - self.model.iController.clean_states() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - torch.cuda.nvtx.range_push("LlamaForCausalLM") - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - torch.cuda.nvtx.range_push("lm_head") - hidden_states = outputs[0] - if self.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split( - self.vocab_size // self.pretraining_tp, dim=0 - ) - logits = [ - F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp) - ] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - - torch.cuda.nvtx.range_pop() - torch.cuda.nvtx.range_pop() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past - ), - ) - return reordered_past + self.iController: Optional[InferenceController] = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value 
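One recurring pattern in the optimized modules of this file is that normalization is applied in place to `hidden_states[:1]` only, so any extra rows carried along the batch dimension bypass normalization and attention until they are concatenated back after `o_proj`. A minimal illustration of that slice-assignment trick, using a plain `LayerNorm` as a stand-in for `LlamaRMSNorm` and made-up sizes:

```python
import torch
import torch.nn as nn

norm = nn.LayerNorm(8)                 # stand-in for LlamaRMSNorm
hidden_states = torch.randn(3, 4, 8)   # (bsz, seq_len, hidden_size)
before = hidden_states.clone()

# Only batch row 0 is normalized in place; rows 1.. are left untouched.
hidden_states[:1] = norm(hidden_states[:1])

assert not torch.allclose(hidden_states[0], before[0])
assert torch.allclose(hidden_states[1:], before[1:])
```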
+ + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + seq_length = inputs_embeds.shape[1] + + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + # Configure Quest Controller + # Prepare indices/indptr for newly appended tokens + assert self.iController is not None, "Please init Quest Controller first." 
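When no `cache_position` is passed, the forward above derives it from the number of tokens already stored in the cache, which is what keeps RoPE offsets and cache updates aligned between prefill and decode. A tiny numeric illustration (the values are made up):

```python
import torch

# Prefill: nothing cached yet, positions are 0..seq_len-1.
past_seen_tokens, seq_len = 0, 5
print(torch.arange(past_seen_tokens, past_seen_tokens + seq_len))  # tensor([0, 1, 2, 3, 4])

# Decode: five tokens cached, one new token arrives at position 5.
past_seen_tokens, seq_len = 5, 1
print(torch.arange(past_seen_tokens, past_seen_tokens + seq_len))  # tensor([5])
```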
+ self.iController.prepare_metadata(seq_length) + + # Skip layers by setting infinite budgets + if self._quest_skip_layer > 0: + self.iController.set_page_budget(self._quest_max_page_limit) + self.iController.begin_forward(seq_length) + + for idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + # Configure regular skipping layers + if idx == self._quest_skip_layer: + self.iController.end_forward() + self.iController.set_page_budget(self._quest_page_budget) + # Avoid the redundant init/copy of metadata + # if previous skip layer does, then skip it again + self.iController.begin_forward(seq_length, updateTensor=(idx == 0)) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + iController=self.iController, + ) + + hidden_states = layer_outputs[0] + + # if use_cache: + # next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + self.iController.end_forward() + hidden_states[:1] = self.norm(hidden_states[:1]) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + # if return_legacy_cache: + # next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self._config = config # saved for quest init + + # Initialize weights and apply final processing + self.post_init() + + def quest_init( + self, + page_size: int, + max_seq_len: int, + token_budget: int = 512, + dtype: torch.dtype = torch.float16, + device=torch.device("cuda:0"), + ): + """ + Init function for Quest. Must be called before forwarding. + This function allocates all GPU memory for max_seq_len KV-Cache. + """ + assert self.model.iController is None, "Can't init Quest Controller twice." + + config = self._config + self.model._quest_page_size = page_size + self.model._quest_page_budget = token_budget // page_size # default page budget + self.model._quest_max_page_limit = 1024 * 1024 # arbitraty large size + self.model._quest_skip_layer = 2 + + self.model.iController = InferenceController( + num_layers=config.num_hidden_layers, + num_heads=config.num_attention_heads, + head_dim=config.hidden_size // config.num_attention_heads, + page_size=page_size, + page_budget=self.model._quest_page_budget, + max_seq_len=max_seq_len, # Used for allocating KV Pools + dtype=dtype, + device=device, + ) + + print(f"Quest allocates KV-Cache of {max_seq_len} tokens") + print(f"Token budget is set to {token_budget}") + + def reset_model(self): + """ + Assistant function for cleaning states of KV-Cache, + which prepares for a new conversation. + """ + assert self.model.iController is not None, "Must be called after init." 
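For concreteness, the budget arithmetic in `quest_init` above works out as follows; the page size and model dimensions below are illustrative assumptions, while `token_budget` uses the function's default of 512.

```python
# Illustrative numbers for quest_init(page_size=16, ...) with the default token_budget.
page_size, token_budget = 16, 512
page_budget = token_budget // page_size        # 32 candidate pages per decode step
assert page_budget == 32

hidden_size, num_attention_heads = 4096, 32    # assumed Llama-7B-style config
head_dim = hidden_size // num_attention_heads  # 128, passed to InferenceController
assert head_dim == 128
```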
+ self.model.iController.clean_states() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + use_cache = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +def enable_quest_attention_eval(model: LlamaForCausalLM, args: Dict): + cache_budget = args["cache_budget"] + page_size = args["page_size"] + max_seq_len = args.get("max_seq_len", model.config.max_position_embeddings) + dtype = args.get("dtype", torch.float16) + device = args.get("device", torch.device("cuda:0")) + model.quest_init(page_size, max_seq_len, cache_budget, dtype, device) diff --git a/quest/models/quest_llama_optimized_bak.py b/quest/models/quest_llama_optimized_bak.py new file mode 100644 index 0000000..5011cd1 --- /dev/null +++ b/quest/models/quest_llama_optimized_bak.py @@ -0,0 +1,998 @@ +# Based on HuggingFace Llama Model: models/llama/modeling_llama.py +# transformers==4.31.0 #TODO + +""" PyTorch LLaMA model.""" +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +import quest.quest_utils +from quest.quest_utils import rms_norm_forward +from quest.quest_utils.controller import InferenceController + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, +): + """ + Make causal mask used for bi-directional self-attention. 
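Taken together, `quest_init`, `reset_model`, and `enable_quest_attention_eval` form the small public surface for running this optimized model. A hedged usage sketch follows; the import path assumes this file lives at `quest/models/quest_llama_optimized.py`, and the checkpoint name, prompt, and budget are placeholders rather than values prescribed by the code.

```python
import torch
from transformers import AutoTokenizer

# LlamaForCausalLM / enable_quest_attention_eval as defined in this file (assumed module path).
from quest.models.quest_llama_optimized import LlamaForCausalLM, enable_quest_attention_eval

model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",        # placeholder checkpoint
    device_map="cuda:0",
    torch_dtype=torch.float16,
)
# Allocates the paged KV cache and sets the sparse token budget.
enable_quest_attention_eval(model, {"cache_budget": 512, "page_size": 16})

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hello, world", return_tensors="pt").to("cuda:0")

with torch.no_grad():
    out = model(input_ids=inputs.input_ids)     # prefill; the controller appends KV pages
next_token = out.logits[:, -1].argmax(dim=-1)   # greedy pick for the next decode step

model.reset_model()   # clear controller state before starting a new sequence
```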
+ """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1 + ) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + return rms_norm_forward(hidden_states, self.weight, self.variance_epsilon) + + +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.pretraining_tp = config.pretraining_tp + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.pretraining_tp > 1: + slice = self.intermediate_size // self.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat( + [F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1 + ) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) + for i in range(self.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +class QuestAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.pretraining_tp = config.pretraining_tp + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self._init_rope() + + def _init_rope(self): + # rope_theta is default to 1e4, as set in RoPE kernel API. + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings + ) + self.rope_scale = 1.0 + else: + scaling_type = self.config.rope_scaling["type"] + if scaling_type == "linear": + # support for Longchat-v1.5. + self.rope_scale = self.config.rope_scaling["factor"] + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + iController: Optional[quest.quest_utils.InferenceController] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + assert bsz == 1, "QuestAttention only supports batch size 1." + assert hasattr(self, "layer_idx"), "QuestAttention requires layer_idx to inference." + + if self.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [ + F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) + ] + query_states = torch.cat(query_states, dim=-1) + + key_states = [ + F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) + ] + key_states = torch.cat(key_states, dim=-1) + + value_states = [ + F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) + ] + value_states = torch.cat(value_states, dim=-1) + + else: + torch.cuda.nvtx.range_push("qkv_proj") + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + torch.cuda.nvtx.range_pop() + + # Not transposed for Append kv cache NHD layout + query_states = query_states.view(q_len, self.num_heads, self.head_dim) + key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) + + torch.cuda.nvtx.range_push("RoPE") + quest.quest_utils.apply_rope_in_place( + query_states, + key_states, + iController.kv_cache.seqlen - q_len, + rope_scale=self.rope_scale, + ) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push("append_kv") + # Quest manages KV-Cache internal (with PageAttention) + # Here we do not concat / stack + # We concat after RoPE + quest.quest_utils.append_kv( + key_states, + value_states, + iController, + self.layer_idx, + ) + torch.cuda.nvtx.range_pop() + + # Prefill/Decode kernels is 
different + if q_len > 1: + torch.cuda.nvtx.range_push("prefill_attn") + attn_output = quest.quest_utils.prefill_forward( + query_states, + iController, + self.layer_idx, + ) + torch.cuda.nvtx.range_pop() + else: + # Skipping layers is controled by PAGE_BUDGET, which is set in LlamaModel. + if iController.need_estimate() == False: + torch.cuda.nvtx.range_push("full_attn") + attn_output = quest.quest_utils.decode_sparse_attn( + query_states, + iController, + self.layer_idx, + iController.kv_indices_without_last, + ) + torch.cuda.nvtx.range_pop() + else: + torch.cuda.nvtx.range_push("estimate") + estimated_attn_score = quest.quest_utils.decode_estimate( + query_states, + iController, + self.layer_idx, + ) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push("topk") + quest.quest_utils.decode_topk( + estimated_attn_score, + iController, + ) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push("approx_attn") + attn_output = quest.quest_utils.decode_sparse_attn( + query_states, + iController, + self.layer_idx, + iController.topk_dindices_buffer, + ) + torch.cuda.nvtx.range_pop() + + attn_output = attn_output.unsqueeze(0) # unsqueeze the batch dimension + # FlashInfer output is naturally NHD + # Note that we manully control NHD. Should be more general + if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + torch.cuda.nvtx.range_push("o_proj") + if self.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) + attn_output = sum( + [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)] + ) + else: + attn_output = self.o_proj(attn_output) + torch.cuda.nvtx.range_pop() + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = QuestAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + iController: Optional[InferenceController] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + torch.cuda.nvtx.range_push("input_norm") + hidden_states = self.input_layernorm(hidden_states) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push("LlamaAttention") + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + iController=iController, + ) + torch.cuda.nvtx.range_pop() + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + torch.cuda.nvtx.range_push("norm") + hidden_states = self.post_attention_layernorm(hidden_states) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push("mlp") + hidden_states = self.mlp(hidden_states) + torch.cuda.nvtx.range_pop() + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlamaModel): + module.gradient_checkpointing = value + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + + # Leave Quest controller as uninitialized + self.iController = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # KV-Cache is managed by iController + # if past_key_values is not None: + # past_key_values_length = past_key_values[0][0].shape[2] + # seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else 
inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + torch.cuda.nvtx.range_push(f"embed") + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + torch.cuda.nvtx.range_pop() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # Configure Quest Controller + # Prepare indices/indptr for newly appended tokens + assert self.iController is not None, "Please init Quest Controller first." + self.iController.prepare_metadata(seq_length) + + # Skip layers by setting infinite budgets + if self._quest_skip_layer > 0: + self.iController.set_page_budget(self._quest_max_page_limit) + self.iController.begin_forward(seq_length) + + for idx, decoder_layer in enumerate(self.layers): + # Configure regular skipping layers + if idx == self._quest_skip_layer: + self.iController.end_forward() + self.iController.set_page_budget(self._quest_page_budget) + # Avoid the redundant init/copy of metadata + # if previous skip layer does, then skip it again + self.iController.begin_forward(seq_length, updateTensor=(idx == 0)) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # past_key_value = past_key_values[idx] if past_key_values is not None else None + # KV-Cache Managed by ourselves + past_key_value = None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + torch.cuda.nvtx.range_push(f"layer={idx}") + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + iController=self.iController, + ) + torch.cuda.nvtx.range_pop() + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # Configure Quest Controller + self.iController.end_forward() + + torch.cuda.nvtx.range_push("lastnorm") + hidden_states = self.norm(hidden_states) + torch.cuda.nvtx.range_pop() + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] 
+ if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.pretraining_tp = config.pretraining_tp + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self._config = config # saved for quest init + # Initialize weights and apply final processing + self.post_init() + + def quest_init( + self, + page_size: int, + max_seq_len: int, + token_budget: int = 512, + dtype: torch.dtype = torch.float16, + device=torch.device("cuda:0"), + ): + """ + Init function for Quest. Must be called before forwarding. + This function allocates all GPU memory for max_seq_len KV-Cache. + """ + assert self.model.iController is None, "Can't init Quest Controller twice." + + config = self._config + self.model._quest_page_size = page_size + self.model._quest_page_budget = token_budget // page_size # default page budget + self.model._quest_max_page_limit = 1024 * 1024 # arbitraty large size + self.model._quest_skip_layer = 2 + + self.model.iController = InferenceController( + num_layers=config.num_hidden_layers, + num_heads=config.num_attention_heads, + head_dim=config.hidden_size // config.num_attention_heads, + page_size=page_size, + page_budget=self.model._quest_page_budget, + max_seq_len=max_seq_len, # Used for allocating KV Pools + dtype=dtype, + device=device, + ) + + print(f"Quest allocates KV-Cache of {max_seq_len} tokens") + print(f"Token budget is set to {token_budget}") + + def reset_model(self): + """ + Assistant function for cleaning states of KV-Cache, + which prepares for a new conversation. + """ + assert self.model.iController is not None, "Must be called after init." + self.model.iController.clean_states() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + torch.cuda.nvtx.range_push("LlamaForCausalLM") + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + torch.cuda.nvtx.range_push("lm_head") + hidden_states = outputs[0] + if self.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split( + self.vocab_size // self.pretraining_tp, dim=0 + ) + logits = [ + F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp) + ] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_pop() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": 
past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past diff --git a/quest/models/raas_llama_optimized.py b/quest/models/raas_llama_optimized.py new file mode 100644 index 0000000..6932b54 --- /dev/null +++ b/quest/models/raas_llama_optimized.py @@ -0,0 +1,1026 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Based on HuggingFace Llama Model: models/llama/modeling_llama.py +# transformers==4.47.0 +from typing import List, Optional, Tuple, Union, Dict + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.llama.configuration_llama import LlamaConfig + +import quest.raas_utils as raas_utils +from quest.raas_utils.controller import InferenceController + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf" +_CONFIG_FOR_DOC = "LlamaConfig" + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + return raas_utils.rms_norm_forward(hidden_states, self.weight, self.variance_epsilon) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[LlamaConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, 
only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from + (seq_len, num_key_value_heads, head_dim) to (seqlen, num_attention_heads, head_dim) + """ + seq_len, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :].expand(seq_len, num_key_value_heads, n_rep, head_dim) + return hidden_states.reshape(seq_len, num_key_value_heads * n_rep, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + iController: Optional[InferenceController] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + _bsz, q_len, _ = hidden_states.size() + bsz = 1 + + # only get the first few tokens be calculated + # assert bsz == 1, "RaaSAttention only supports batch size 1." + hidden_states_first = hidden_states[:1] + ori_dtype = hidden_states_first.dtype + query_states = self.q_proj(hidden_states_first).to(torch.float16) + key_states = self.k_proj(hidden_states_first).to(torch.float16) + value_states = self.v_proj(hidden_states_first).to(torch.float16) + + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + # Not transposed for Append kv cache NHD layout + query_states = query_states.view(q_len, self.num_heads, self.head_dim) + key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) + + # Hack for GQA: we need to repeat the key and value states to match the number of heads + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + raas_utils.apply_rope_in_place( + query_states, + key_states, + iController.kv_cache.seqlen - q_len, + rope_scale=self.rope_scaling, + rope_theta=self.rope_theta, + ) + + + # RaaS manages KV-Cache internal (with PageAttention) + # Here we do not concat / stack + # We concat after RoPE + raas_utils.append_kv( + key_states, + value_states, + iController, + self.layer_idx, + ) + + if q_len > 1: + attn_output = raas_utils.prefill_forward( + query_states, + iController, + self.layer_idx, + ) + else: + # Skipping layers is controled by PAGE_BUDGET, which is set in LlamaModel. 
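+ # Decode-time page selection (RaaS): the first `_raas_skip_layer` layers run with the
+ # page budget raised to `max_page_limit` (set in LlamaModel.forward), so they fall back
+ # to full attention over all cached pages. Later layers keep the real page budget:
+ # estimate a per-page attention score, keep the highest-scoring pages within the budget,
+ # and run sparse attention only over the pages saved for this layer.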
+ # if not iController.need_estimate(): + + if iController._page_budget == iController.max_page_limit: + # INFO(raas): skip the first several layers + attn_output = raas_utils.decode_sparse_attn( + query_states, + iController, + self.layer_idx, + iController.kv_indices_without_last, + ) + else: + # breakpoint() + estimated_attn_score = raas_utils.decode_estimate( + query_states, + iController, + self.layer_idx, + ) + + raas_utils.decode_topk( + estimated_attn_score, + iController, + self.layer_idx, + ) + + attn_output = raas_utils.decode_sparse_attn( + query_states, + iController, + self.layer_idx, + # iController.topk_dindices_buffer, + iController.get_saved_pages(self.layer_idx), + ) + + attn_output = attn_output.unsqueeze(0) # unsqueeze the batch dimension + + # FlashInfer output is naturally NHD + # Note that we manully control NHD. Should be more general + if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).to(ori_dtype) + + + attn_output = self.o_proj(attn_output) + # concate attn_output with hidden_states[1:, ...] + attn_output = torch.cat([attn_output, hidden_states[1:]], dim=0) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + iController: Optional[InferenceController] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states[:1] = self.input_layernorm(hidden_states[:1]) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + iController=iController, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states[:1] = self.post_attention_layernorm(hidden_states[:1]) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. 
Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") + + # Leave RaaS controller as uninitialized + self.iController: Optional[InferenceController] = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + seq_length = inputs_embeds.shape[1] + + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + # Configure RaaS Controller + # Prepare indices/indptr for newly appended tokens + assert self.iController is not None, "Please init RaaS Controller first." + self.iController.prepare_metadata(seq_length) + + # Skip layers by setting infinite budgets + if self._raas_skip_layer > 0: + self.iController.set_page_budget(self._raas_max_page_limit) + self.iController.begin_forward(seq_length) + + for idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + # Configure regular skipping layers + if idx == self._raas_skip_layer: + self.iController.end_forward() + self.iController.set_page_budget(self._raas_page_budget) + # Avoid the redundant init/copy of metadata + # if previous skip layer does, then skip it again + self.iController.begin_forward(seq_length, updateTensor=(idx == 0)) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + iController=self.iController, + ) + + hidden_states = layer_outputs[0] + + # if use_cache: + # next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + self.iController.end_forward() + if seq_length == 1: # decode + self.iController.update_timestamp() + hidden_states[:1] = self.norm(hidden_states[:1]) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else 
None + # if return_legacy_cache: + # next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. 
+ sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self._config = config # saved for raas init + + # Initialize weights and apply final processing + self.post_init() + + def raas_init( + self, + page_size: int, + max_seq_len: int, + token_budget: int = 512, + dtype: torch.dtype = torch.float16, + device=torch.device("cuda:0"), + ): + """ + Init function for RaaS. Must be called before forwarding. + This function allocates all GPU memory for max_seq_len KV-Cache. + """ + assert self.model.iController is None, "Can't init RaaS Controller twice." + + config = self._config + self.model._raas_page_size = page_size + self.model._raas_page_budget = token_budget // page_size # default page budget + self.model._raas_max_page_limit = 1024 * 1024 # arbitraty large size + self.model._raas_skip_layer = 2 + + self.model.iController = InferenceController( + num_layers=config.num_hidden_layers, + num_heads=config.num_attention_heads, + head_dim=config.hidden_size // config.num_attention_heads, + page_size=page_size, + page_budget=self.model._raas_page_budget, + max_seq_len=max_seq_len, # Used for allocating KV Pools + max_page_limit=self.model._raas_max_page_limit, + dtype=dtype, + device=device, + ) + + print(f"RaaS allocates KV-Cache of {max_seq_len} tokens") + print(f"Token budget is set to {token_budget}") + + def reset_model(self): + """ + Assistant function for cleaning states of KV-Cache, + which prepares for a new conversation. + """ + assert self.model.iController is not None, "Must be called after init." 
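+ # Only the controller's per-conversation KV-cache state is cleared here; the pools
+ # allocated once by raas_init() are kept and reused for the next request.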
+ self.model.iController.clean_states() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + use_cache = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +def enable_raas_attention_eval(model: LlamaForCausalLM, args: Dict): + cache_budget = args["cache_budget"] + page_size = args["page_size"] + max_seq_len = args.get("max_seq_len", model.config.max_position_embeddings) + dtype = args.get("dtype", torch.float16) + device = args.get("device", torch.device("cuda:0")) + model.raas_init(page_size, max_seq_len, cache_budget, dtype, device) diff --git a/quest/quest_ops/CMakeLists.txt b/quest/quest_ops/CMakeLists.txt new file mode 100644 index 0000000..2c158e7 --- /dev/null +++ b/quest/quest_ops/CMakeLists.txt @@ -0,0 +1,47 @@ +cmake_minimum_required(VERSION 3.26.4) # Specify your minimum CMake version + +set(CMAKE_C_COMPILER "/usr/bin/gcc") +set(CMAKE_CXX_COMPILER "/usr/bin/g++") +# set(CMAKE_C_COMPILER "/usr/bin/gcc-11") +# set(CMAKE_CXX_COMPILER "/usr/bin/g++-11") +set(CMAKE_C_STANDARD 17) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) + +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES native) +endif() + +# ------------- configure rapids-cmake --------------# +include(${CMAKE_SOURCE_DIR}/cmake/fetch_rapids.cmake) +include(rapids-cmake) +include(rapids-cpm) +include(rapids-cuda) +include(rapids-export) +include(rapids-find) + +project(_quest_kernels LANGUAGES CUDA CXX) # Replace with your project's name + +# ------------- configure raft -----------------# +rapids_cpm_init() +include(${CMAKE_SOURCE_DIR}/cmake/get_raft.cmake) + +# Check: https://stackoverflow.com/questions/68401650/how-can-i-make-a-pytorch-extension-with-cmake +# Fix linking error: https://github.com/pytorch/pytorch/issues/108041 +find_package(Python REQUIRED COMPONENTS Interpreter Development) +find_package(Torch REQUIRED) +find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib") + +# Try combine pybind +# Check: https://qiita.com/syoyo/items/c3e8e6e5c3e2d69c2325 +add_subdirectory(${CMAKE_SOURCE_DIR}/../../kernels/3rdparty/pybind ${CMAKE_BINARY_DIR}/pybind11) + +file(GLOB PYTORCH_SOURCES "csrc/*.cu") +pybind11_add_module(_quest_kernels MODULE ${PYTORCH_CPP_SOURCES} ${PYTORCH_SOURCES}) + 
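The `enable_raas_attention_eval` helper above is the single entry point the eval harness uses to switch a checkpoint onto the RaaS path. A minimal sketch of that flow follows; the checkpoint name, budget value, and prompt are illustrative assumptions, and real runs go through the harness's manual prefill/decode loop rather than this toy forward pass:

```python
import torch
from transformers import AutoTokenizer
from quest.models.raas_llama_optimized import LlamaForCausalLM, enable_raas_attention_eval

model_name = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint, not fixed by this patch
model = LlamaForCausalLM.from_pretrained(
    model_name,
    device_map="cuda:0",
    trust_remote_code=True,
    torch_dtype=torch.float16,  # optimized kernels run in fp16
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Calls raas_init() under the hood: allocates the paged KV pools once,
# with 16-token pages and a 512-token budget (illustrative values).
enable_raas_attention_eval(model, {"cache_budget": 512, "page_size": 16})

inputs = tokenizer("Hello, world", return_tensors="pt").to("cuda:0")
with torch.no_grad():
    out = model(input_ids=inputs.input_ids)  # prefill; the controller appends KV pages internally
first_token = out.logits[:, -1, :].argmax(dim=-1)

# Clear the controller state before starting the next conversation.
model.reset_model()
```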
+target_compile_definitions(_quest_kernels PRIVATE -DBSK_TORCH_CHECK) # Enable Torch Tensor Dimension Check +target_include_directories(_quest_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/include) +target_include_directories(_quest_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/3rdparty/flashinfer/include) +target_include_directories(_quest_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/3rdparty/pybind/include) +target_compile_options(_quest_kernels PRIVATE $<$:--expt-extended-lambda --expt-relaxed-constexpr>) +target_link_libraries(_quest_kernels PRIVATE ${TORCH_LIBRARIES} raft::raft Python::Python pybind11::module ${TORCH_PYTHON_LIBRARY}) \ No newline at end of file diff --git a/quest/ops/cmake/fetch_rapids.cmake b/quest/quest_ops/cmake/fetch_rapids.cmake similarity index 100% rename from quest/ops/cmake/fetch_rapids.cmake rename to quest/quest_ops/cmake/fetch_rapids.cmake diff --git a/quest/ops/cmake/get_raft.cmake b/quest/quest_ops/cmake/get_raft.cmake similarity index 100% rename from quest/ops/cmake/get_raft.cmake rename to quest/quest_ops/cmake/get_raft.cmake diff --git a/quest/ops/csrc/approx_attn.cu b/quest/quest_ops/csrc/approx_attn.cu similarity index 100% rename from quest/ops/csrc/approx_attn.cu rename to quest/quest_ops/csrc/approx_attn.cu diff --git a/quest/ops/csrc/batch_prefill.cu b/quest/quest_ops/csrc/batch_prefill.cu similarity index 100% rename from quest/ops/csrc/batch_prefill.cu rename to quest/quest_ops/csrc/batch_prefill.cu diff --git a/quest/quest_ops/csrc/bsk_ops.cu b/quest/quest_ops/csrc/bsk_ops.cu new file mode 100644 index 0000000..85b553f --- /dev/null +++ b/quest/quest_ops/csrc/bsk_ops.cu @@ -0,0 +1,20 @@ +#include +#include "bsk_ops.h" + +PYBIND11_MODULE(_quest_kernels, m) { + m.def("apply_rope_in_place", &apply_rope_in_place, "Apply RoPE on Q/K in place."); + m.def("rms_norm_forward", &rms_norm_forward, "rms_norm_forward by cutlass"); + m.def("topk_filtering", &topk_filtering, "Top-k filtering operator"); + m.def("estimate_attn_score", &estimate_attn_score, "Estimate Attention Score operator"); + m.def("append_kv_cache_prefill", &append_kv_cache_prefill, "Append KV-Cache Prefill operator"); + m.def("append_kv_cache_decode", &append_kv_cache_decode, "Append KV-Cache Decode operator"); + m.def("prefill_with_paged_kv_cache", + &prefill_with_paged_kv_cache, + "Multi-request batch prefill with paged KV-Cache operator"); + py::class_( + m, "BatchDecodeWithPagedKVCachePyTorchWrapper") + .def(py::init(&BatchDecodeWithPagedKVCachePyTorchWrapper::Create)) + .def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward) + .def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward) + .def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward); +} \ No newline at end of file diff --git a/quest/ops/csrc/bsk_ops.h b/quest/quest_ops/csrc/bsk_ops.h similarity index 100% rename from quest/ops/csrc/bsk_ops.h rename to quest/quest_ops/csrc/bsk_ops.h diff --git a/quest/ops/csrc/estimate.cu b/quest/quest_ops/csrc/estimate.cu similarity index 100% rename from quest/ops/csrc/estimate.cu rename to quest/quest_ops/csrc/estimate.cu diff --git a/quest/ops/csrc/page.cu b/quest/quest_ops/csrc/page.cu similarity index 100% rename from quest/ops/csrc/page.cu rename to quest/quest_ops/csrc/page.cu diff --git a/quest/ops/csrc/pytorch_extension_utils.h b/quest/quest_ops/csrc/pytorch_extension_utils.h similarity index 100% rename from quest/ops/csrc/pytorch_extension_utils.h rename to quest/quest_ops/csrc/pytorch_extension_utils.h diff 
--git a/quest/ops/csrc/rms_norm.cu b/quest/quest_ops/csrc/rms_norm.cu similarity index 100% rename from quest/ops/csrc/rms_norm.cu rename to quest/quest_ops/csrc/rms_norm.cu diff --git a/quest/ops/csrc/topk.cu b/quest/quest_ops/csrc/topk.cu similarity index 100% rename from quest/ops/csrc/topk.cu rename to quest/quest_ops/csrc/topk.cu diff --git a/quest/ops/setup.sh b/quest/quest_ops/setup.sh similarity index 100% rename from quest/ops/setup.sh rename to quest/quest_ops/setup.sh diff --git a/quest/quest_utils/__init__.py b/quest/quest_utils/__init__.py new file mode 100644 index 0000000..420e6d6 --- /dev/null +++ b/quest/quest_utils/__init__.py @@ -0,0 +1,276 @@ +import torch +import math +from typing import Optional + +import quest._quest_kernels as _kernels +from quest.quest_utils.utils import TensorLayout +from quest.quest_utils.kv_cache import KvCache +from quest.quest_utils.controller import InferenceController +from quest.quest_utils.decode_wrapper import BatchDecodeWithPagedKVCacheWrapper + +__all__ = [ + 'TensorLayout', + 'KvCache', + 'InferenceController', + "BatchDecodeWithPagedKVCacheWrapper", + "append_kv", + "prefill_forward", + "decode_estimate", + "decode_topk", + "decode_sparse_attn", + "rms_norm_forward", + "apply_rope_in_place", +] + +def apply_rope_in_place( + q: torch.Tensor, + k: torch.Tensor, + past_kv_len: int, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, +): + """ + Semantics of `apply_rope_in_place`: + Apply RoPE (Relative Positional Encoding) in-place. + On q, k which is generated by GEMM. Layout is naturally NHD. + + Args: + q: Shape: `[N, H, D]`. + k: Shape: `[N, H, D]`. + past_kv_len: Length of past KV cache. Used to calculate frequency. + """ + if rope_scale is None: + rope_scale = 1.0 + if rope_theta is None: + rope_theta = 1e4 + _kernels.apply_rope_in_place( + q, + k, + past_kv_len, + rope_scale, + rope_theta, + ) + +def rms_norm_forward( + input: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> torch.Tensor: + o = torch.empty_like(input, dtype=input.dtype, device=input.device) + f = _kernels.rms_norm_forward + f( + input, + weight, + o, + epsilon, + ) + return o + +def append_kv( + k: torch.Tensor, + v: torch.Tensor, + iController: InferenceController, + layer_idx: int, +): + """ + Semantics of `append_kv`: + Append new generated k/v into kv cache and meta data cache. + Automatically dispatch to Prefill / Decode Kernel + + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + k: Shape: `[B, N, D]`. Key projection (`X @ W_k`). + v: Shape: `[B, N, D]`. Value projection (`X @ W_v`). + iController: InferenceController object, which contains all needed information. + layer_idx: Layer index of the KV cache. 
+ """ + seq_len = k.size(0) + if seq_len > 1: + _kernels.append_kv_cache_prefill( + k, + v, + iController.kv_cache.buf_layer(layer_idx), + iController.kv_indices_with_last, + iController.kv_indptr_for_append, + iController.kv_cache.last_page_len, + iController.kv_last_page_idx, + iController.metadata_cache.buf_layer(layer_idx), + iController.metadata_indices, + iController.metadata_indptr_for_append, + iController.metadata_cache.last_page_len, + iController.metadata_last_page_idx, + iController.layout + ) + else: + _kernels.append_kv_cache_decode( + k, + v, + iController.kv_cache.buf_layer(layer_idx), + iController.kv_indices_with_last, + iController.kv_indptr_for_append, + iController.kv_cache.last_page_len, + iController.kv_last_page_idx, + iController.metadata_cache.buf_layer(layer_idx), + iController.metadata_indices, + iController.metadata_indptr_for_append, + iController.metadata_cache.last_page_len, + iController.metadata_last_page_idx, + iController.layout + ) + +def prefill_forward( + q: torch.Tensor, + iController: InferenceController, + layer_idx: int, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, +) -> torch.Tensor: + """ + Semantics of `prefill_forward`: + Newly generated K/Vs are already in the KV cache and metadata cache (well-maintained). + Perform FlashInfer self-attention with causal masking. + Note that there is no position shift, and the current version does not support prefill optimization. + + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + q: Shape: `[B, N, D]`. Query projection (`X @ W_q`). + iController: InferenceController object, which contains all needed information. + layer_idx: Layer index of the KV cache. + """ + if rope_scale is None: + rope_scale = 1.0 + if rope_theta is None: + rope_theta = 1e4 + + f = _kernels.prefill_with_paged_kv_cache + o = f( + q, + iController.kv_cache.buf_layer(layer_idx), + iController.kv_indices_with_last, + iController.kv_cache.last_page_len, + True, # Causal + iController.layout, + False, # FP16 Accumulator for 4090 + rope_scale, + rope_theta, + ) + return o + +def decode_estimate( + q: torch.Tensor, + iController: InferenceController, + layer_idx: int, +) -> torch.Tensor: + """ + Semantics of `decode_estimate`: + When decoding, estimate the attention score for each page. + + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + q: Shape: `[B, N, D]`. Query projection (`X @ W_q`). + iController: InferenceController object, which contains all needed information. + layer_idx: Layer index of the KV cache. + """ + f = _kernels.estimate_attn_score + # (iController.metadata_cache.seqlen - 1) is manually excluding the last element, which is the current page. + o = torch.empty((iController.num_heads, iController.metadata_cache.seqlen - 1), dtype=q.dtype, device=q.device) + f( + q, + o, + iController.metadata_cache.buf_layer(layer_idx), + iController.metadata_indices, + iController.metadata_indptr_for_append, + iController.metadata_cache.last_page_len, # One entry delta is considered by kernel-level implementation + iController.metadata_last_page_idx, + iController.layout, + ) + return o + +def decode_topk( + estimated_attn_score: torch.Tensor, + iController: InferenceController, +): + """ + Semantics of `decode_topk`: + Select the top-k pages with the highest estimated attention scores.
+ + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + estimated_attn_score: Shape: `[N, num_pages]`. Per-page attention scores produced by `decode_estimate`. + iController: InferenceController object, which contains all needed information. + """ + # excluding the last page + page_budget = iController.inference_page_budget - 1 + f = _kernels.topk_filtering + f( + estimated_attn_score, + iController.kv_indices_without_last, + iController.topk_dout_buffer, + iController.topk_dindices_buffer, + iController.topk_buf, + page_budget, + ) + +def decode_sparse_attn( + q: torch.Tensor, + iController: InferenceController, + layer_idx: int, + topk_indices: torch.Tensor, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, +) -> torch.Tensor: + """ + Semantics of `decode_sparse_attn`: + Execute self-attention only on the selected pages (top-k output). + + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + q: Shape: `[B, N, D]`. Query projection (`X @ W_q`). + iController: InferenceController object, which contains all needed information. + layer_idx: Layer index of the KV cache. + topk_indices: Shape: `[N, page_budget-1]`. Top-k indices. + """ + o = torch.empty_like(q, dtype=q.dtype, device=q.device) + iController._decode_handler.forward( + q, + o, + iController.kv_cache.buf_layer(layer_idx), + topk_indices, + iController.kv_indptr_for_approx_decode, + iController.kv_cache.last_page_len, + iController.kv_last_page_idx, + rope_scale, + rope_theta, + ) + return o \ No newline at end of file diff --git a/quest/utils/controller.py b/quest/quest_utils/controller.py similarity index 97% rename from quest/utils/controller.py rename to quest/quest_utils/controller.py index 135b8d2..0e307d7 100644 --- a/quest/utils/controller.py +++ b/quest/quest_utils/controller.py @@ -1,6 +1,6 @@ -from quest.utils.decode_wrapper import BatchDecodeWithPagedKVCacheWrapper -from quest.utils.kv_cache import KvCache -from quest.utils.utils import TensorLayout +from quest.quest_utils.decode_wrapper import BatchDecodeWithPagedKVCacheWrapper +from quest.quest_utils.kv_cache import KvCache +from quest.quest_utils.utils import TensorLayout import torch diff --git a/quest/quest_utils/decode_wrapper.py b/quest/quest_utils/decode_wrapper.py new file mode 100644 index 0000000..c39db2e --- /dev/null +++ b/quest/quest_utils/decode_wrapper.py @@ -0,0 +1,82 @@ +import torch +from typing import Optional + +import quest._quest_kernels as _kernels +from quest.quest_utils.utils import TensorLayout + +def _check_kv_layout(kv_layout: str): + if not hasattr(TensorLayout, kv_layout): + raise KeyError("Invalid kv_layout {}".format(kv_layout)) + +class BatchDecodeWithPagedKVCacheWrapper: + r"""Wrapper class for batch_decode_with_paged_kv_cache kernel. + + To accelerate computation, FlashInfer's batch decode operators create some + auxiliary data structures; these data structures can be reused across multiple + batch decode calls (e.g. different Transformer layers). This wrapper class manages + the lifecycle of these data structures.
+ """ + + def __init__(self, kv_layout: str = "NHD"): + _check_kv_layout(kv_layout) + self.kv_layout = kv_layout + self._wrapper = _kernels.BatchDecodeWithPagedKVCachePyTorchWrapper( + getattr(TensorLayout, kv_layout) + ) + + def begin_forward( + self, + indptr: torch.Tensor, # [0, Page_budget - 1], once per forward for all layers + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: int, + data_type, + ): + r"""The begin_forward method should be called before any batch decode calls, + auxiliary data structures will be created during this call and cached for + multiple forward calls. + """ + + # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info + empty_data = torch.empty(0, dtype=data_type) + self._wrapper.begin_forward( + indptr, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + empty_data, + ) + + def end_forward(self): + r"""The end_forward method can clear the cached data structures.""" + self._wrapper.end_forward() + + def forward( + self, + q: torch.Tensor, + o: torch.Tensor, + paged_kv_data: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_last_page_len: int, + paged_kv_last_page_idx: int, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, + ): + if rope_scale is None: + rope_scale = 1.0 + if rope_theta is None: + rope_theta = 1e4 + self._wrapper.forward( + q, + o, + paged_kv_data, + paged_kv_indices, + paged_kv_indptr, + paged_kv_last_page_len, + paged_kv_last_page_idx, + rope_scale, + rope_theta, + ) \ No newline at end of file diff --git a/quest/utils/kv_cache.py b/quest/quest_utils/kv_cache.py similarity index 98% rename from quest/utils/kv_cache.py rename to quest/quest_utils/kv_cache.py index ed4d939..d9fc835 100644 --- a/quest/utils/kv_cache.py +++ b/quest/quest_utils/kv_cache.py @@ -1,7 +1,7 @@ # This file is modified from Punica Project # Check ref: https://github.com/punica-ai/punica -from quest.utils.utils import TensorLayout +from quest.quest_utils.utils import TensorLayout import torch class KvPool: diff --git a/quest/utils/utils.py b/quest/quest_utils/utils.py similarity index 100% rename from quest/utils/utils.py rename to quest/quest_utils/utils.py diff --git a/quest/ops/CMakeLists.txt b/quest/raas_ops/CMakeLists.txt similarity index 61% rename from quest/ops/CMakeLists.txt rename to quest/raas_ops/CMakeLists.txt index 90e3fcb..ffa98e7 100644 --- a/quest/ops/CMakeLists.txt +++ b/quest/raas_ops/CMakeLists.txt @@ -20,7 +20,7 @@ include(rapids-cuda) include(rapids-export) include(rapids-find) -project(_kernels LANGUAGES CUDA CXX) # Replace with your project's name +project(_raas_kernels LANGUAGES CUDA CXX) # Replace with your project's name # ------------- configure raft -----------------# rapids_cpm_init() @@ -37,11 +37,11 @@ find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib add_subdirectory(${CMAKE_SOURCE_DIR}/../../kernels/3rdparty/pybind ${CMAKE_BINARY_DIR}/pybind11) file(GLOB PYTORCH_SOURCES "csrc/*.cu") -pybind11_add_module(_kernels MODULE ${PYTORCH_CPP_SOURCES} ${PYTORCH_SOURCES}) - -target_compile_definitions(_kernels PRIVATE -DBSK_TORCH_CHECK) # Enable Torch Tensor Dimension Check -target_include_directories(_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/include) -target_include_directories(_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/3rdparty/flashinfer/include) -target_include_directories(_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/3rdparty/pybind/include) -target_compile_options(_kernels PRIVATE 
$<$:--expt-extended-lambda --expt-relaxed-constexpr>) -target_link_libraries(_kernels PRIVATE ${TORCH_LIBRARIES} raft::raft Python::Python pybind11::module ${TORCH_PYTHON_LIBRARY}) \ No newline at end of file +pybind11_add_module(_raas_kernels MODULE ${PYTORCH_CPP_SOURCES} ${PYTORCH_SOURCES}) + +target_compile_definitions(_raas_kernels PRIVATE -DBSK_TORCH_CHECK) # Enable Torch Tensor Dimension Check +target_include_directories(_raas_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/include) +target_include_directories(_raas_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/3rdparty/flashinfer/include) +target_include_directories(_raas_kernels PRIVATE ${CMAKE_SOURCE_DIR}/../../kernels/3rdparty/pybind/include) +target_compile_options(_raas_kernels PRIVATE $<$:--expt-extended-lambda --expt-relaxed-constexpr>) +target_link_libraries(_raas_kernels PRIVATE ${TORCH_LIBRARIES} raft::raft Python::Python pybind11::module ${TORCH_PYTHON_LIBRARY}) \ No newline at end of file diff --git a/quest/raas_ops/cmake/fetch_rapids.cmake b/quest/raas_ops/cmake/fetch_rapids.cmake new file mode 100644 index 0000000..15b6c43 --- /dev/null +++ b/quest/raas_ops/cmake/fetch_rapids.cmake @@ -0,0 +1,21 @@ +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. + +# Use this variable to update RAPIDS and RAFT versions +set(RAPIDS_VERSION "24.02") + +if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) + file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake + ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) +endif() +include(${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) diff --git a/quest/raas_ops/cmake/get_raft.cmake b/quest/raas_ops/cmake/get_raft.cmake new file mode 100644 index 0000000..6128b5c --- /dev/null +++ b/quest/raas_ops/cmake/get_raft.cmake @@ -0,0 +1,63 @@ +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. 
+ +# Use RAPIDS_VERSION from cmake/thirdparty/fetch_rapids.cmake +set(RAFT_VERSION "${RAPIDS_VERSION}") +set(RAFT_FORK "rapidsai") +set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") + +function(find_and_configure_raft) + set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY ENABLE_NVTX ENABLE_MNMG_DEPENDENCIES) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + set(RAFT_COMPONENTS "") + if(PKG_COMPILE_LIBRARY) + string(APPEND RAFT_COMPONENTS " compiled") + endif() + + if(PKG_ENABLE_MNMG_DEPENDENCIES) + string(APPEND RAFT_COMPONENTS " distributed") + endif() + + #----------------------------------------------------- + # Invoke CPM find_package() + #----------------------------------------------------- + rapids_cpm_find(raft ${PKG_VERSION} + GLOBAL_TARGETS raft::raft + BUILD_EXPORT_SET raft-template-exports + INSTALL_EXPORT_SET raft-template-exports + COMPONENTS ${RAFT_COMPONENTS} + CPM_ARGS + GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git + GIT_TAG ${PKG_PINNED_TAG} + SOURCE_SUBDIR cpp + OPTIONS + "BUILD_TESTS OFF" + "BUILD_PRIMS_BENCH OFF" + "BUILD_ANN_BENCH OFF" + "RAFT_NVTX ${ENABLE_NVTX}" + "RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}" + ) +endfunction() + +# Change pinned tag here to test a commit in CI +# To use a different RAFT locally, set the CMake variable +# CPM_raft_SOURCE=/path/to/local/raft +find_and_configure_raft(VERSION ${RAFT_VERSION}.00 + FORK ${RAFT_FORK} + PINNED_TAG ${RAFT_PINNED_TAG} + COMPILE_LIBRARY ON + ENABLE_MNMG_DEPENDENCIES OFF + ENABLE_NVTX OFF +) diff --git a/quest/raas_ops/csrc/approx_attn.cu b/quest/raas_ops/csrc/approx_attn.cu new file mode 100644 index 0000000..314f6f9 --- /dev/null +++ b/quest/raas_ops/csrc/approx_attn.cu @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + Modified from FlashInfer PyTorch API. 
+ Check: https://github.com/flashinfer-ai/flashinfer/blob/main/python/csrc/batch_decode.cu +*/ + +#include "bsk_ops.h" +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(torch::Tensor indptr, + unsigned int num_qo_heads, + unsigned int num_kv_heads, + unsigned int head_dim, + unsigned int page_size, + torch::Tensor empty_data) { + constexpr size_t batch_size = 1; + + #ifdef BSK_TORCH_CHECK + CHECK_CONTIGUOUS(indptr); + CHECK_DIM(1, indptr); + CHECK_EQ(indptr.scalar_type(), torch::kInt32); + #endif + + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] { + SWITCH_LAYOUT(kv_layout_, KV_LAYOUT, { + cudaError_t status = + handler_.BeginForward( + static_cast(indptr.data_ptr()), + batch_size, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + RotaryMode::kNone); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }) + }); + + TORCH_CHECK(success, + "BatchDecodeWithPagedKVCache failed to dispatch with dtype ", + empty_data.scalar_type()); +} + +void BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward() { + handler_.EndForward(); +} + +void +BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(torch::Tensor q, + torch::Tensor o, + torch::Tensor paged_kv_data, + torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_indptr, // [1, Page_budget] + unsigned int paged_kv_last_page_len, + unsigned int paged_kv_last_page_idx, + float rope_scale, + float rope_theta) { + constexpr size_t batch_size = 1; + + #ifdef BSK_TORCH_CHECK + CHECK_INPUT(q); + CHECK_INPUT(paged_kv_data); + CHECK_INPUT(paged_kv_indices); + CHECK_DIM(3, q); // (B, H_qo, D) + CHECK_DIM(2, paged_kv_indices); // (num_heads, page_budget - 1) + // (num_max_pages, 2, H_kv, page_size, head_dim) for HND + // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD + CHECK_DIM(5, paged_kv_data); + #endif + + int64_t num_qo_heads = q.size(1); + int64_t head_dim = q.size(2); + int64_t num_kv_heads, page_size; + // This is the stride of the paged_kv_indices tensor + // actual page budget is page_budget + 1 + int64_t page_budget = paged_kv_indices.size(1); + + if(kv_layout_ == QKVLayout::kHND) { + num_kv_heads = paged_kv_data.size(2); + page_size = paged_kv_data.size(3); + } else { + page_size = paged_kv_data.size(2); + num_kv_heads = paged_kv_data.size(3); + } + + #ifdef BSK_TORCH_CHECK + CHECK_EQ(paged_kv_indices.size(0), num_qo_heads); + CHECK_EQ(paged_kv_data.size(1), 2); + CHECK_EQ(paged_kv_data.size(4), head_dim); + CHECK_EQ(paged_kv_indices.scalar_type(), torch::kInt32); + #endif + + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + SWITCH_LAYOUT(kv_layout_, KV_LAYOUT, { + paged_kv_t paged_kv( + num_kv_heads, + page_size, + head_dim, + batch_size, + page_budget, + paged_kv_last_page_len, + paged_kv_last_page_idx, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr())); + + cudaError_t status = + BatchDecodeWithPagedKVCacheWrapper(&handler_, + static_cast(q.data_ptr()), + paged_kv, + static_cast(o.data_ptr()), + /*lse=*/nullptr, + num_qo_heads, + RotaryMode::kNone, + rope_scale, + rope_theta, + /*stream=*/nullptr); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + }); + return true; + }); + + TORCH_CHECK( + success, "BatchDecodeWithPagedKVCache failed to dispatch 
with dtype ", q.scalar_type()); +} \ No newline at end of file diff --git a/quest/raas_ops/csrc/batch_prefill.cu b/quest/raas_ops/csrc/batch_prefill.cu new file mode 100644 index 0000000..7d15669 --- /dev/null +++ b/quest/raas_ops/csrc/batch_prefill.cu @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + Modified from FlashInfer project. + Check: https://github.com/flashinfer-ai/flashinfer/blob/main/python/csrc/batch_prefill.cu +*/ + +#include "bsk_ops.h" +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +torch::Tensor prefill_with_paged_kv_cache(torch::Tensor q, + torch::Tensor kv_data, + torch::Tensor kv_indices, + unsigned int kv_last_page_len, + bool causal, + unsigned int layout, + bool allow_fp16_qk_reduction, + float rope_scale, + float rope_theta) { + constexpr size_t batch_size = 1; + + #ifdef BSK_TORCH_CHECK + CHECK_INPUT(q); // [sum(extend_len), num_qo_heads, head_dim] + // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND + // [max_num_pages, 2, page_size, num_kv_heads, head_dim] for HND + CHECK_INPUT(kv_data); + CHECK_INPUT(kv_indices); // [sum(seq_len)] + #endif + + // bsk only utilizes flashinfer for bsz=1. Therefore we can infer some parameters. 
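+ // With batch size fixed at 1, q_indptr reduces to [0, q_len] and kv_indptr to [0, num_pages].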
+ torch::Tensor q_indptr = torch::tensor({0, static_cast(q.size(0))}, kv_indices.options()); + torch::Tensor kv_indptr = torch::tensor({0, static_cast(kv_indices.size(0))}, kv_indices.options()); + + #ifdef BSK_TORCH_CHECK + CHECK_DIM(3, q); + CHECK_DIM(5, kv_data) + CHECK_DIM(1, q_indptr); + CHECK_DIM(1, kv_indptr); + CHECK_DIM(1, kv_indices); + CHECK_EQ(q_indptr.size(0), kv_indptr.size(0)); + CHECK_EQ(kv_indices.scalar_type(), torch::kInt32); + CHECK_EQ(q.size(2), kv_data.size(4)); + #endif + + QKVLayout kv_layout = QKVLayout(layout); + unsigned int page_size, num_kv_heads; + if(kv_layout == QKVLayout::kHND) { + num_kv_heads = kv_data.size(2); + page_size = kv_data.size(3); + } else { + page_size = kv_data.size(2); + num_kv_heads = kv_data.size(3); + } + unsigned int head_dim = q.size(2); + unsigned int num_qo_heads = q.size(1); + + auto o = torch::empty_like(q, q.options()); + + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + SWITCH_LAYOUT(kv_layout, KV_LAYOUT, { + paged_kv_t paged_kv( + num_kv_heads, + page_size, + head_dim, + batch_size, + 0, + kv_last_page_len, + kv_indices[-1].item(), + static_cast(kv_data.data_ptr()), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr())); + + cudaError_t status = + BatchPrefillWithPagedKVCache(static_cast(q.data_ptr()), + static_cast(q_indptr.data_ptr()), + paged_kv, + static_cast(o.data_ptr()), + /*tmp=*/nullptr, + /*lse=*/nullptr, + num_qo_heads, + causal, + RotaryMode::kNone, + allow_fp16_qk_reduction, + rope_scale, + rope_theta); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCache failed with error code ", + cudaGetErrorString(status)); + }); + return true; + }); + + TORCH_CHECK( + success, "BatchPrefillWithPagedKVCache failed to dispatch with dtype ", q.scalar_type()); + + return o; +} \ No newline at end of file diff --git a/quest/ops/csrc/bsk_ops.cu b/quest/raas_ops/csrc/bsk_ops.cu similarity index 96% rename from quest/ops/csrc/bsk_ops.cu rename to quest/raas_ops/csrc/bsk_ops.cu index 300885f..a8494f0 100644 --- a/quest/ops/csrc/bsk_ops.cu +++ b/quest/raas_ops/csrc/bsk_ops.cu @@ -1,7 +1,7 @@ #include #include "bsk_ops.h" -PYBIND11_MODULE(_kernels, m) { +PYBIND11_MODULE(_raas_kernels, m) { m.def("apply_rope_in_place", &apply_rope_in_place, "Apply RoPE on Q/K in place."); m.def("rms_norm_forward", &rms_norm_forward, "rms_norm_forward by cutlass"); m.def("topk_filtering", &topk_filtering, "Top-k filtering operator"); diff --git a/quest/raas_ops/csrc/bsk_ops.h b/quest/raas_ops/csrc/bsk_ops.h new file mode 100644 index 0000000..b0bec29 --- /dev/null +++ b/quest/raas_ops/csrc/bsk_ops.h @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include + +#include "decode/decode_handler.cuh" +#include "prefill/prefill.cuh" +#include "topk/decode_select_k.cuh" + +void apply_rope_in_place(torch::Tensor q, + torch::Tensor k, + unsigned int past_kv_len, + float rope_scale, + float rope_theta); + +void rms_norm_forward(torch::Tensor input, + torch::Tensor weight, + torch::Tensor output, + float epsilon); + +void topk_filtering(torch::Tensor estimated_value, + torch::Tensor estimated_indices, + torch::Tensor d_out, + torch::Tensor indices_out, + torch::Tensor buf, + unsigned int page_budget); + +void estimate_attn_score(torch::Tensor q, + torch::Tensor o, + torch::Tensor metadata_data, + torch::Tensor metadata_indices, + torch::Tensor metadata_indptr, + unsigned int metadata_last_page_len, + unsigned int metadata_last_page_idx, + unsigned int layout); + +void append_kv_cache_prefill(torch::Tensor k, + torch::Tensor v, + torch::Tensor kv_data, + torch::Tensor kv_indices, + torch::Tensor kv_indptr, + unsigned int kv_last_page_len, + unsigned int kv_last_page_idx, + torch::Tensor metadata_data, + torch::Tensor metadata_indices, + torch::Tensor metadata_indptr, + unsigned int metadata_last_page_len, + unsigned int metadata_last_page_idx, + unsigned int layout); + +void append_kv_cache_decode(torch::Tensor k, + torch::Tensor v, + torch::Tensor kv_data, + torch::Tensor kv_indices, + torch::Tensor kv_indptr, + unsigned int kv_last_page_len, + unsigned int kv_last_page_idx, + torch::Tensor metadata_data, + torch::Tensor metadata_indices, + torch::Tensor metadata_indptr, + unsigned int metadata_last_page_len, + unsigned int metadata_last_page_idx, + unsigned int layout); + +torch::Tensor prefill_with_paged_kv_cache(torch::Tensor q, + torch::Tensor kv_data, + torch::Tensor kv_indices, + unsigned int kv_last_page_len, + bool causal, + unsigned int layout, + bool allow_fp16_qk_reduction, + float rope_scale, + float rope_theta); + +class BatchDecodeWithPagedKVCachePyTorchWrapper { +public: + static BatchDecodeWithPagedKVCachePyTorchWrapper Create(unsigned int layout) { + return BatchDecodeWithPagedKVCachePyTorchWrapper(layout); + } + void BeginForward(torch::Tensor indptr, + unsigned int num_qo_heads, + unsigned int num_kv_heads, + unsigned int head_dim, + unsigned int page_size, + torch::Tensor empty_data); + + void EndForward(); + + void Forward(torch::Tensor q, + torch::Tensor o, + torch::Tensor paged_kv_data, + torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_indptr, + unsigned int paged_kv_last_page_len, + unsigned int paged_kv_last_page_idx, + float rope_scale, + float rope_theta); + +private: + BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout) + : kv_layout_(flashinfer::QKVLayout(layout)) { } + flashinfer::BatchDecodeHandler handler_; + flashinfer::QKVLayout kv_layout_; +}; \ No newline at end of file diff --git a/quest/raas_ops/csrc/estimate.cu b/quest/raas_ops/csrc/estimate.cu new file mode 100644 index 0000000..aecb5bd --- /dev/null +++ b/quest/raas_ops/csrc/estimate.cu @@ -0,0 +1,84 @@ +#include "bsk_ops.h" +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +void estimate_attn_score(torch::Tensor q, + torch::Tensor o, + torch::Tensor metadata_data, + torch::Tensor metadata_indices, + torch::Tensor metadata_indptr, + unsigned int metadata_last_page_len, + unsigned int metadata_last_page_idx, + unsigned int layout) { + constexpr size_t batch_size = 1; + + #ifdef BSK_TORCH_CHECK + CHECK_INPUT(q); // [1, num_heads, head_dim] + // (num_max_pages, 2, H_kv, page_size, head_dim) for 
HND + // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD + CHECK_INPUT(metadata_data); + CHECK_INPUT(metadata_indices); + + CHECK_DIM(3, q); + CHECK_DIM(5, metadata_data); + CHECK_DIM(1, metadata_indices); + + CHECK_EQ(q.size(0), 1); + CHECK_EQ(metadata_indices.scalar_type(), torch::kInt32); + #endif + + size_t num_heads = q.size(1); + size_t head_dim = q.size(2); + size_t page_size; + + QKVLayout kv_layout = static_cast(layout); + if(kv_layout == QKVLayout::kHND) { + page_size = metadata_data.size(3); + #ifdef BSK_TORCH_CHECK + CHECK_EQ(metadata_data.size(2), num_heads); + CHECK_EQ(metadata_data.size(4), head_dim); + #endif + } else { + page_size = metadata_data.size(2); + #ifdef BSK_TORCH_CHECK + CHECK_EQ(metadata_data.size(3), num_heads); + CHECK_EQ(metadata_data.size(4), head_dim); + #endif + } + + // size_t output_len = (metadata_indices.size(0) - 1) * page_size + metadata_last_page_len - 1; + // torch::Tensor o = torch::empty( + // {static_cast(num_heads), static_cast(output_len)}, q.options()); + + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + SWITCH_LAYOUT(kv_layout, KV_LAYOUT, { + paged_kv_t paged_kv( + num_heads, + page_size, + head_dim, + batch_size, + 0, + metadata_last_page_len, + metadata_last_page_idx, + static_cast(metadata_data.data_ptr()), + static_cast(metadata_indices.data_ptr()), + static_cast(metadata_indptr.data_ptr())); + cudaError_t status = + MaxPossibleSampleWithPagedKVCache(static_cast(q.data_ptr()), + paged_kv, + static_cast(o.data_ptr()), + num_heads, + /*rotary_mode*/ RotaryMode::kNone); + TORCH_CHECK(status == cudaSuccess, + "Estimate_attn_score failed with error code ", + cudaGetErrorString(status)); + }); + return true; + }); + TORCH_CHECK(success, "Estimate_attn_score failed to dispatch with dtype ", q.scalar_type()); +} \ No newline at end of file diff --git a/quest/raas_ops/csrc/page.cu b/quest/raas_ops/csrc/page.cu new file mode 100644 index 0000000..a43b0aa --- /dev/null +++ b/quest/raas_ops/csrc/page.cu @@ -0,0 +1,253 @@ +#include "bsk_ops.h" +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +void append_kv_cache_decode(torch::Tensor k, + torch::Tensor v, + torch::Tensor kv_data, + torch::Tensor kv_indices, + torch::Tensor kv_indptr, + unsigned int kv_last_page_len, + unsigned int kv_last_page_idx, + torch::Tensor metadata_data, + torch::Tensor metadata_indices, + torch::Tensor metadata_indptr, + unsigned int metadata_last_page_len, + unsigned int metadata_last_page_idx, + unsigned int layout) { + constexpr size_t batch_size = 1; + CHECK_INPUT(k); // [bsz, num_heads, head_dim] + CHECK_INPUT(v); // [bsz, num_heads, head_dim] + // (num_max_pages, 2, H_kv, page_size, head_dim) for HND + // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD + CHECK_INPUT(kv_data); + CHECK_INPUT(kv_indices); // [num_pages] + CHECK_INPUT(metadata_data); + CHECK_INPUT(metadata_indices); // [num_pages] + + CHECK_DIM(1, kv_indices); + CHECK_DIM(1, metadata_indices); + CHECK_DIM(3, k); + CHECK_DIM(3, v); + CHECK_DIM(5, kv_data); + CHECK_DIM(5, metadata_data); + + CHECK_EQ(k.size(0), 1); // decode + CHECK_EQ(v.size(0), 1); // decode + CHECK_EQ(kv_indices.scalar_type(), torch::kInt32); + CHECK_EQ(metadata_indices.scalar_type(), torch::kInt32); + CHECK_EQ(kv_indptr.scalar_type(), torch::kInt32); + CHECK_EQ(metadata_indptr.scalar_type(), torch::kInt32); + + size_t num_heads = k.size(1); + size_t head_dim = k.size(2); + size_t page_size; + QKVLayout kv_layout = static_cast(layout); + if(kv_layout == QKVLayout::kHND) { 
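+ // HND layout stores kv_data as (num_max_pages, 2, H_kv, page_size, head_dim), so page_size sits on axis 3.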
+ page_size = kv_data.size(3); + CHECK_EQ(kv_data.size(2), num_heads); + CHECK_EQ(kv_data.size(4), head_dim); + } else { + page_size = kv_data.size(2); + CHECK_EQ(kv_data.size(3), num_heads); + CHECK_EQ(kv_data.size(4), head_dim); + } + + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(k.scalar_type(), c_type, [&] { + SWITCH_LAYOUT(kv_layout, KV_LAYOUT, { + paged_kv_t paged_kv( + num_heads, + page_size, + head_dim, + batch_size, + 0, + kv_last_page_len, + kv_last_page_idx, + static_cast(kv_data.data_ptr()), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr())); + + paged_kv_t paged_metadata( + num_heads, + page_size, + head_dim, + batch_size, + 0, + metadata_last_page_len, + metadata_last_page_idx, + static_cast(metadata_data.data_ptr()), + static_cast(metadata_indices.data_ptr()), + static_cast(metadata_indptr.data_ptr())); + + cudaError_t status = + AppendPagedKVCacheDecode( + paged_kv, + paged_metadata, + static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + nullptr); + + TORCH_CHECK(status == cudaSuccess, + "Append_kv_cache_decode failed with error code ", + cudaGetErrorString(status)); + }); + return true; + }); + + TORCH_CHECK(success, "Append_kv_cache_decode failed to dispatch with dtype ", k.scalar_type()); +} + +void append_kv_cache_prefill(torch::Tensor k, + torch::Tensor v, + torch::Tensor kv_data, + torch::Tensor kv_indices, + torch::Tensor kv_indptr, + unsigned int kv_last_page_len, + unsigned int kv_last_page_idx, + torch::Tensor metadata_data, + torch::Tensor metadata_indices, + torch::Tensor metadata_indptr, + unsigned int metadata_last_page_len, + unsigned int metadata_last_page_idx, + unsigned int layout) { + constexpr size_t batch_size = 1; + +#ifdef BSK_TORCH_CHECK + CHECK_INPUT(k); // [bsz, num_heads, head_dim] + CHECK_INPUT(v); // [bsz, num_heads, head_dim] + // (num_max_pages, 2, H_kv, page_size, head_dim) for HND + // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD + CHECK_INPUT(kv_data); + CHECK_INPUT(kv_indices); // [num_pages] + CHECK_INPUT(metadata_data); + CHECK_INPUT(metadata_indices); // [num_pages] + + CHECK_DIM(1, kv_indices); + CHECK_DIM(1, metadata_indices); + CHECK_DIM(3, k); + CHECK_DIM(3, v); + CHECK_DIM(5, kv_data); + CHECK_DIM(5, metadata_data); + + CHECK_GE(k.size(0), 2); // Prefill + CHECK_GE(v.size(0), 2); // Prefill + CHECK_EQ(kv_indices.scalar_type(), torch::kInt32); + CHECK_EQ(metadata_indices.scalar_type(), torch::kInt32); + CHECK_EQ(kv_indptr.scalar_type(), torch::kInt32); + CHECK_EQ(metadata_indptr.scalar_type(), torch::kInt32); +#endif + + size_t seq_len = k.size(0); + size_t num_heads = k.size(1); + size_t head_dim = k.size(2); + size_t page_size; + QKVLayout kv_layout = static_cast(layout); + if(kv_layout == QKVLayout::kHND) { + page_size = kv_data.size(3); +#ifdef BSK_TORCH_CHECK + CHECK_EQ(kv_data.size(2), num_heads); + CHECK_EQ(kv_data.size(4), head_dim); +#endif + } else { + page_size = kv_data.size(2); +#ifdef BSK_TORCH_CHECK + CHECK_EQ(kv_data.size(3), num_heads); + CHECK_EQ(kv_data.size(4), head_dim); +#endif + } + +#ifdef BSK_TORCH_CHECK + CHECK_EQ(seq_len, v.size(0)); +#endif + + torch::Tensor append_indptr = + torch::tensor({0, static_cast(seq_len)}, kv_indices.options()); + + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(k.scalar_type(), c_type, [&] { + SWITCH_LAYOUT(kv_layout, KV_LAYOUT, { + paged_kv_t paged_kv( + num_heads, + page_size, + head_dim, + batch_size, + 0, + kv_last_page_len, + kv_last_page_idx, + static_cast(kv_data.data_ptr()), + static_cast(kv_indices.data_ptr()), + 
static_cast(kv_indptr.data_ptr())); + + paged_kv_t paged_metadata( + num_heads, + page_size, + head_dim, + batch_size, + 0, + metadata_last_page_len, + metadata_last_page_idx, + static_cast(metadata_data.data_ptr()), + static_cast(metadata_indices.data_ptr()), + static_cast(metadata_indptr.data_ptr())); + + cudaError_t status = + AppendPagedKVCachePrefill( + paged_kv, + paged_metadata, + static_cast(k.data_ptr()), + static_cast(v.data_ptr()), + static_cast(append_indptr.data_ptr()), + nullptr); + + TORCH_CHECK(status == cudaSuccess, + "Append_kv_cache_prefill failed with error code ", + cudaGetErrorString(status)); + }); + return true; + }); + + TORCH_CHECK(success, "Append_kv_cache_prefill failed to dispatch with dtype ", k.scalar_type()); +} + +void apply_rope_in_place(torch::Tensor q, + torch::Tensor k, + unsigned int past_kv_len, + float rope_scale, + float rope_theta) { +#ifdef BSK_TORCH_CHECK + // Note: input layout is always NHD. Not Paged. + CHECK_INPUT(q); // [seq_len, num_heads, head_dim] + CHECK_INPUT(k); // [seq_len, num_heads, head_dim] + + CHECK_DIM(3, q); + CHECK_DIM(3, k); + + CHECK_EQ(q.size(0), k.size(0)); + CHECK_EQ(q.size(1), k.size(1)); + CHECK_EQ(q.size(2), k.size(2)); +#endif + + size_t seq_len = q.size(0); + size_t num_heads = q.size(1); + size_t head_dim = q.size(2); + + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + cudaError_t status = QKApplyRotaryInPlace(static_cast(q.data_ptr()), + static_cast(k.data_ptr()), + seq_len, + past_kv_len, + num_heads, + num_heads, + head_dim, + rope_scale, + rope_theta, + nullptr); + + TORCH_CHECK(status == cudaSuccess, + "apply_rope_in_place failed with error code ", + cudaGetErrorString(status)); + return true; + }); + + TORCH_CHECK(success, "apply_rope_in_place failed to dispatch with dtype ", k.scalar_type()); +} \ No newline at end of file diff --git a/quest/raas_ops/csrc/pytorch_extension_utils.h b/quest/raas_ops/csrc/pytorch_extension_utils.h new file mode 100644 index 0000000..d8bf7f9 --- /dev/null +++ b/quest/raas_ops/csrc/pytorch_extension_utils.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + Modified from FlashInfer PyTorch API. + Check: https://github.com/flashinfer-ai/flashinfer/blob/main/python/csrc/pytorch_extension_utils.h +*/ + +#pragma once +#include + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch(pytorch_dtype) { \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + default: \ + return false; \ + } \ + }() + +inline void check_shape(const torch::Tensor& a, + const torch::Tensor& b, + const char* a_name, + const char* b_name) { + TORCH_CHECK( + a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). 
", a.dim(), " vs ", b.dim()); + for(int i = 0; i < a.dim(); ++i) { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")"); + } +} + +inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { + return (uint32_t(a) << 16) | uint32_t(b); +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) + +#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b) diff --git a/quest/raas_ops/csrc/rms_norm.cu b/quest/raas_ops/csrc/rms_norm.cu new file mode 100644 index 0000000..18ebeb7 --- /dev/null +++ b/quest/raas_ops/csrc/rms_norm.cu @@ -0,0 +1,213 @@ +// Adapted from cutlass +// https://github.com/NVIDIA/cutlass/blob/7d8317a63e0a978a8dbb3c1fb7af4dbe4f286616/tools/util/include/cutlass/util/device_rmsnorm.h +/****************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#include +#include +#include + +#include +#include + +#include "bsk_ops.h" +#include "pytorch_extension_utils.h" + +template +__inline__ __device__ T warpReduceSum(T *val) { +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(0xffffffff, val[i], mask, 32); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceSum(T *val) { + __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSum(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[i][lane] : (T)(0.0f); + } + warpReduceSum(val); + return (T)0.0f; +} + +template +__global__ void rmsnorm_twoPassAlgo_e8(float4 *__restrict__ output, + const float4 *__restrict__ input, + const float4 *__restrict__ weight, int m, + int n, float epsilon) { + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean; + float local_sums[1] = {0.0f}; + const int n_8 = n / 8; + int offset = m_idx * n_8; + input += offset; + output += offset; + + for (int index = tid; index < n_8; index += bdimx) { + const float4 local_val = input[index]; + const half2 *h1 = (half2 *)&local_val.x; + const half2 *h2 = (half2 *)&local_val.y; + const half2 *h3 = (half2 *)&local_val.z; + const half2 *h4 = (half2 *)&local_val.w; + local_sums[0] += static_cast(h1->x) * static_cast(h1->x) + + static_cast(h1->y) * static_cast(h1->y) + + static_cast(h2->x) * static_cast(h2->x) + + static_cast(h2->y) * static_cast(h2->y) + + static_cast(h3->x) * static_cast(h3->x) + + static_cast(h3->y) * static_cast(h3->y) + + static_cast(h4->x) * static_cast(h4->x) + + static_cast(h4->y) * static_cast(h4->y); + } + + blockReduceSum(local_sums); + if (threadIdx.x == 0) { + s_mean = rsqrtf(local_sums[0] / n + epsilon); + } + __syncthreads(); + + for (int index = tid; index < n_8; index += bdimx) { + const float4 local_val = input[index]; + const float4 weight_val = weight[index]; + + const half2 *l1 = (half2 *)&local_val.x; + const half2 *l2 = (half2 *)&local_val.y; + const half2 *l3 = (half2 *)&local_val.z; + const half2 *l4 = (half2 *)&local_val.w; + + const half2 *g1 = (half2 *)&weight_val.x; + const half2 *g2 = (half2 *)&weight_val.y; + const half2 *g3 = (half2 *)&weight_val.z; + const half2 *g4 = (half2 *)&weight_val.w; + + float4 tmp; + half2 *h1 = (half2 *)&tmp.x; + half2 *h2 = (half2 *)&tmp.y; + half2 *h3 = (half2 *)&tmp.z; + half2 *h4 = (half2 *)&tmp.w; + + h1->x = static_cast(static_cast(l1->x) * s_mean * + static_cast(g1->x)); + h1->y = static_cast(static_cast(l1->y) * s_mean * + static_cast(g1->y)); + h2->x = static_cast(static_cast(l2->x) * s_mean * + static_cast(g2->x)); + h2->y = static_cast(static_cast(l2->y) * s_mean * + static_cast(g2->y)); + h3->x = static_cast(static_cast(l3->x) * s_mean * + static_cast(g3->x)); + h3->y = static_cast(static_cast(l3->y) * s_mean * + static_cast(g3->y)); + h4->x = static_cast(static_cast(l4->x) * s_mean * + static_cast(g4->x)); + h4->y = static_cast(static_cast(l4->y) * s_mean * + static_cast(g4->y)); + + output[index] = tmp; + } +} + +template +bool rms_norm(T *__restrict__ output, const T *__restrict__ input, + const T *__restrict__ weight, int 
rows, int columns, + float epsilon) { + if (columns % 8 != 0) { + return false; + } + + dim3 grid(rows); + dim3 block(std::min(1024, (columns / 8 + 31) / 32 * 32)); + + if (std::is_same::value) { + rmsnorm_twoPassAlgo_e8 + <<>>((float4 *)output, (float4 *)input, (float4 *)weight, + rows, columns, epsilon); + return true; + } else if (std::is_same::value) { + rmsnorm_twoPassAlgo_e8 + <<>>((float4 *)output, (float4 *)input, (float4 *)weight, + rows, columns, epsilon); + return true; + } + return false; +} + +void rms_norm_forward( + torch::Tensor input, + torch::Tensor weight, + torch::Tensor output, // empty_like from pyTorch + float epsilon +){ + // shape `(batch, seq_len, embed_dim)` + CHECK_INPUT(input); + CHECK_INPUT(weight); + CHECK_EQ(input.dim(), 3); + + int rows = input.size(1); + int columns = input.size(2); + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(input.scalar_type(), c_type, [&] { + bool status = + rms_norm( + static_cast(output.data_ptr()), + static_cast(input.data_ptr()), + static_cast(weight.data_ptr()), + rows, + columns, + epsilon + ); + + TORCH_CHECK(status == true, + "rms_norm failed with error code "); + return true; + }); +} \ No newline at end of file diff --git a/quest/raas_ops/csrc/topk.cu b/quest/raas_ops/csrc/topk.cu new file mode 100644 index 0000000..2e87645 --- /dev/null +++ b/quest/raas_ops/csrc/topk.cu @@ -0,0 +1,46 @@ +#include "bsk_ops.h" +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +// Note that estimated_indices does not contain the last page +void topk_filtering(torch::Tensor estimated_value, + torch::Tensor estimated_indices, + torch::Tensor d_out, + torch::Tensor indices_out, + torch::Tensor buf, + unsigned int page_budget) { + #ifdef BSK_TORCH_CHECK + CHECK_INPUT(estimated_value); // [num_heads, num_pages] + CHECK_INPUT(estimated_indices); // [num_heads, num_pages] + CHECK_DIM(2, estimated_value); + CHECK_DIM(2, estimated_indices); + #endif + + auto num_heads = estimated_value.size(0); + auto num_pages = estimated_value.size(1); + + #ifdef BSK_TORCH_CHECK + CHECK_EQ(num_pages, estimated_indices.size(1)); + CHECK_EQ(num_heads, estimated_indices.size(0)); + CHECK_GE(num_pages, page_budget); + CHECK_EQ(estimated_indices.scalar_type(), torch::kInt32); + CHECK_EQ(32, num_heads); // Not necessary, but for Llama-7b + CHECK_EQ(page_budget, d_out.size(1)); + CHECK_EQ(page_budget, indices_out.size(1)); + #endif + + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(estimated_value.scalar_type(), c_type, [&] { + decode_select_k( + static_cast(estimated_value.data_ptr()), + static_cast(estimated_indices.data_ptr()), + static_cast(buf.data_ptr()), + num_pages, + page_budget, + static_cast(d_out.data_ptr()), + static_cast(indices_out.data_ptr()), + true); + return true; + }); + TORCH_CHECK(success, "Top-k filtering failed to dispatch with dtype ", estimated_value.scalar_type()); +} \ No newline at end of file diff --git a/quest/raas_ops/setup.sh b/quest/raas_ops/setup.sh new file mode 100644 index 0000000..6b1f647 --- /dev/null +++ b/quest/raas_ops/setup.sh @@ -0,0 +1,15 @@ +mkdir -p build +cd build + +cmake -DCMAKE_PREFIX_PATH=`python -c 'import torch;print(torch.utils.cmake_prefix_path)'` -GNinja .. +ninja + +echo "Compilation Finish" +cd .. +for file in $(find "./build" -maxdepth 1 -name "*.so"); do + abs_file=$(realpath $file) + if [ -e $abs_file ]; then + ln -s $abs_file ../ + echo "Copied $abs_file..." 
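
As a sanity check for the fused kernel bound by `rms_norm_forward` above, a plain PyTorch reference of the same computation (`x * rsqrt(mean(x^2) + eps) * weight`, accumulated in fp32 the way the two-pass kernel does) can be useful. The helper name `rms_norm_ref` is ours, not part of the extension; the GPU path can then be compared against it once the `.so` is built.

```
import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Accumulate the sum of squares in fp32, mirroring the kernel's first pass,
    # then scale by rsqrt(mean + eps) and the per-channel weight in the second pass.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps) * weight.float()).to(x.dtype)

# Shapes follow the binding: (batch, seq_len, embed_dim) input, weight of size embed_dim.
x = torch.randn(1, 8, 4096, dtype=torch.float16)
w = torch.randn(4096, dtype=torch.float16)
y_ref = rms_norm_ref(x, w, 1e-6)
```
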
+ fi +done \ No newline at end of file diff --git a/quest/raas_utils/__init__.py b/quest/raas_utils/__init__.py new file mode 100644 index 0000000..7a94c43 --- /dev/null +++ b/quest/raas_utils/__init__.py @@ -0,0 +1,277 @@ +import torch +import math +from typing import Optional + +import quest._raas_kernels as _kernels +from quest.raas_utils.utils import TensorLayout +from quest.raas_utils.kv_cache import KvCache +from quest.raas_utils.controller import InferenceController +from quest.raas_utils.decode_wrapper import BatchDecodeWithPagedKVCacheWrapper + +__all__ = [ + 'TensorLayout', + 'KvCache', + 'InferenceController', + "BatchDecodeWithPagedKVCacheWrapper", + "append_kv", + "prefill_forward", + "decode_estimate", + "decode_topk", + "decode_sparse_attn", + "rms_norm_forward", + "apply_rope_in_place", +] + +def apply_rope_in_place( + q: torch.Tensor, + k: torch.Tensor, + past_kv_len: int, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, +): + """ + Semantics of `apply_rope_in_place`: + Apply RoPE (Relative Positional Encoding) in-place. + On q, k which is generated by GEMM. Layout is naturally NHD. + + Args: + q: Shape: `[N, H, D]`. + k: Shape: `[N, H, D]`. + past_kv_len: Length of past KV cache. Used to calculate frequency. + """ + if rope_scale is None: + rope_scale = 1.0 + if rope_theta is None: + rope_theta = 1e4 + _kernels.apply_rope_in_place( + q, + k, + past_kv_len, + rope_scale, + rope_theta, + ) + +def rms_norm_forward( + input: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> torch.Tensor: + o = torch.empty_like(input, dtype=input.dtype, device=input.device) + f = _kernels.rms_norm_forward + f( + input, + weight, + o, + epsilon, + ) + return o + +def append_kv( + k: torch.Tensor, + v: torch.Tensor, + iController: InferenceController, + layer_idx: int, +): + """ + Semantics of `append_kv`: + Append new generated k/v into kv cache and meta data cache. + Automatically dispatch to Prefill / Decode Kernel + + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + k: Shape: `[B, N, D]`. Key projection (`X @ W_k`). + v: Shape: `[B, N, D]`. Value projection (`X @ W_v`). + iController: InferenceController object, which contains all needed information. + layer_idx: Layer index of the KV cache. 
+ """ + seq_len = k.size(0) + if seq_len > 1: + _kernels.append_kv_cache_prefill( + k, + v, + iController.kv_cache.buf_layer(layer_idx), + iController.kv_indices_with_last, + iController.kv_indptr_for_append, + iController.kv_cache.last_page_len, + iController.kv_last_page_idx, + iController.metadata_cache.buf_layer(layer_idx), + iController.metadata_indices, + iController.metadata_indptr_for_append, + iController.metadata_cache.last_page_len, + iController.metadata_last_page_idx, + iController.layout + ) + else: + _kernels.append_kv_cache_decode( + k, + v, + iController.kv_cache.buf_layer(layer_idx), + iController.kv_indices_with_last, + iController.kv_indptr_for_append, + iController.kv_cache.last_page_len, + iController.kv_last_page_idx, + iController.metadata_cache.buf_layer(layer_idx), + iController.metadata_indices, + iController.metadata_indptr_for_append, + iController.metadata_cache.last_page_len, + iController.metadata_last_page_idx, + iController.layout + ) + +def prefill_forward( + q: torch.Tensor, + iController: InferenceController, + layer_idx: int, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, +) -> torch.Tensor: + """ + Semantics of `prefill_forward`: + New genrated K/Vs are already in the kv cache and meta data cache (well-maintained). + Perform FlashInfer Self-Attention with Casual Attention. + Note that we not have position shift and current version not support Prefill Optimization. + + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + q: Shape: `[B, N, D]`. Key projection (`X @ W_k`). + iController: InferenceController object, which contains all needed information. + layer_idx: Layer index of the KV cache. + """ + if rope_scale is None: + rope_scale = 1.0 + if rope_theta is None: + rope_theta = 1e4 + + f = _kernels.prefill_with_paged_kv_cache + o = f( + q, + iController.kv_cache.buf_layer(layer_idx), + iController.kv_indices_with_last, + iController.kv_cache.last_page_len, + True, # Casual + iController.layout, + False, # FP16 Accumulator for 4090 + rope_scale, + rope_theta, + ) + return o + +def decode_estimate( + q: torch.Tensor, + iController: InferenceController, + layer_idx: int, +) -> torch.Tensor: + """ + Semantics of `decode_estimate`: + When decoding, estimate the attention score for each page. + + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + q: Shape: `[B, N, D]`. Key projection (`X @ W_k`). + iController: InferenceController object, which contains all needed information. + layer_idx: Layer index of the KV cache. + """ + f = _kernels.estimate_attn_score + # (iController.metadata_cache.seqlen - 1) is manually excluding the last elements, which is the current page. + o = torch.empty((iController.num_heads, iController.metadata_cache.seqlen - 1), dtype=q.dtype, device=q.device) + f( + q, + o, + iController.metadata_cache.buf_layer(layer_idx), + iController.metadata_indices, + iController.metadata_indptr_for_append, + iController.metadata_cache.last_page_len, # One entry delta is considered by kernel-level implementation + iController.metadata_last_page_idx, + iController.layout, + ) + return o + +def decode_topk( + estimated_attn_score: torch.Tensor, + iController: InferenceController, + layer_idx: int, +): + """ + Semantics of `decode_topk`: + select top-k pages with highest attention score. 
+ + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + q: Shape: `[B, N, D]`. Key projection (`X @ W_k`). + iController: InferenceController object, which contains all needed information. + layer_idx: Layer index of the KV cache. + """ + # excluding the last page; half of the max-pages limited. This pages are then flagged as the top-k pages. + page_budet = min(iController.inference_page_budget - 1, iController.origin_page_budget // 2) + f = _kernels.topk_filtering + f( + estimated_attn_score, + iController.kv_indices_without_last, + iController.topk_dout_buffer, + iController.topk_dindices_buffer[layer_idx], + iController.topk_buf, + page_budet, + ) + +def decode_sparse_attn( + q: torch.Tensor, + iController: InferenceController, + layer_idx: int, + topk_indices: torch.Tensor, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, +) -> torch.Tensor: + """ + Semantics of `decode_sparse_attn`: + Excute self-attention only on the selected pages (Top-k output) + + Notations for shapes: + `B`: batch size + `N`: number of heads + `D`: head dimension + `L`: number of layers + `MAXLEN`: maximum length of the KV cache + + Args: + q: Shape: `[B, N, D]`. Key projection (`X @ W_k`). + iController: InferenceController object, which contains all needed information. + layer_idx: Layer index of the KV cache. + topk_indices: Shape: `[N, page_budget-1]`. Top-k indices. + """ + o = torch.empty_like(q, dtype=q.dtype, device=q.device) + iController._decode_handler.forward( + q, + o, + iController.kv_cache.buf_layer(layer_idx), + topk_indices, + iController.kv_indptr_for_approx_decode, + iController.kv_cache.last_page_len, + iController.kv_last_page_idx, + rope_scale, + rope_theta, + ) + return o \ No newline at end of file diff --git a/quest/raas_utils/controller.py b/quest/raas_utils/controller.py new file mode 100644 index 0000000..4dad4a2 --- /dev/null +++ b/quest/raas_utils/controller.py @@ -0,0 +1,208 @@ +from quest.raas_utils.decode_wrapper import BatchDecodeWithPagedKVCacheWrapper +from quest.raas_utils.kv_cache import KvCache +from quest.raas_utils.utils import TensorLayout +from logging import getLogger +import torch + +logger = getLogger(__name__) + +class InferenceController: + def __init__( + self, + num_layers, + num_heads, + head_dim, + page_size, + page_budget, # Real page budget including the last page + max_seq_len, # Real max for allocating kv / metadata + max_page_limit, + dtype, + device, + ): + max_kv_pages_num = (max_seq_len + page_size - 1) // page_size + self.max_kv_pages_num = max_kv_pages_num + self.num_layers = num_layers + self.kv_cache = KvCache( + num_layers=num_layers, + num_heads=num_heads, + head_dim=head_dim, + max_seq_len=max_seq_len, + page_size=page_size, + dtype=dtype, + device=device + ) + self.metadata_cache = KvCache( + num_layers=num_layers, + num_heads=num_heads, + head_dim=head_dim, + max_seq_len=max_kv_pages_num, + page_size=page_size, + dtype=dtype, + device=device + ) + self.layout = TensorLayout.NHD # Arbitrarily choose NHD. 
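
Taken together, the wrappers above compose into a per-layer decode step along the following lines. This is only an illustrative sketch: the constructor arguments (notably `max_page_limit`), the CUDA device, and the exact call order are inferred from these wrappers and from the existing tests, and a real run needs a prefill pass first so the cache holds more than one page.

```
import torch
import quest.raas_utils as raas_utils

num_layers, num_heads, head_dim, page_size = 32, 32, 128, 16
ctrl = raas_utils.InferenceController(
    num_layers, num_heads, head_dim,
    page_size=page_size, page_budget=64,
    max_seq_len=4096, max_page_limit=4096 // page_size,
    dtype=torch.float16, device="cuda",
)

# A prefill pass (prepare_metadata / begin_forward / append_kv / prefill_forward)
# is assumed to have populated the cache already, as in quest/tests.

q = torch.randn(1, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(1, num_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(1, num_heads, head_dim, dtype=torch.float16, device="cuda")

layer = 0
ctrl.prepare_metadata(1)        # reserve one token slot, stamp new pages, possibly evict
ctrl.begin_forward(1)           # build page index tensors and top-k buffers for this step
raas_utils.append_kv(k, v, ctrl, layer)

scores = raas_utils.decode_estimate(q, ctrl, layer)   # per-page scores, last page excluded
raas_utils.decode_topk(scores, ctrl, layer)           # fills ctrl.topk_dindices_buffer[layer]
out = raas_utils.decode_sparse_attn(q, ctrl, layer, ctrl.topk_dindices_buffer[layer])

ctrl.update_timestamp()         # mark the selected pages as recently used
ctrl.end_forward()
```
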
+ self.device = device + self.dtype = dtype + + self.num_heads = num_heads + self.head_dim = head_dim + self.page_size = page_size + self.max_page_limit = max_page_limit + + self._page_budget = page_budget + self.origin_page_budget = page_budget + self._decode_handler = BatchDecodeWithPagedKVCacheWrapper(kv_layout="NHD") + + self.kv_indices_with_last = None + self.kv_indices_without_last = None + self.metadata_indices = None + self.kv_last_page_idx = None # For decoding self-attention + self.metadata_last_page_idx = None + + self.kv_indptr_for_append = None + self.metadata_indptr_for_append = None + self.kv_indptr_for_approx_decode = None + + self.inference_page_budget = None + + self.topk_dout_buffer = None + self.topk_dindices_buffer = None + self.topk_buf = None + + self.saved_page_num = None + self.saved_page_index = None + self.saved_pages = None + self.saved_pages_continuous = None + # used by raas: the time stamps + self.timestamps = None + # breakpoint() + self.clean_states() + + + # Used for controlling the number of pages + # Here we skip first two layers by manipulating this. + def set_page_budget(self, page_budget: int): + self._page_budget = page_budget + + # Called once per forwarding in all layers + # Adjust the metadata for paged_kv + def prepare_metadata(self, seq_len: int): + # breakpoint() + # Allocate entry for tokens + old_pages = len(self.kv_cache.indicies) + appended_new_pages = self.kv_cache.append_seq(seq_len) + # Allocate entry for metadata + _ = self.metadata_cache.append_seq(appended_new_pages) + now_pages = len(self.kv_cache.indicies) + + self.timestamps[:, :, old_pages:now_pages] = self.kv_cache.seqlen # newest timestamp + self.saved_page_num += appended_new_pages + + if self.saved_page_num > self.origin_page_budget: + if seq_len > 1: + logger.warning("Evicing while prefill!") + self.saved_page_num = self.origin_page_budget + else: + assert self.saved_page_num == self.origin_page_budget + 1 + # breakpoint() + oldest_ind = torch.argmin(self.timestamps[:, :, :now_pages], dim=-1).unsqueeze(-1).long() # shape: [32, 32, 1] + pages_ind = torch.gather(self.saved_page_index, dim=-1, index=oldest_ind).long() # shape: [32, 32, 1] + self.saved_pages.scatter_(dim=2, index=pages_ind, value=now_pages - 1) + self.timestamps.scatter_(dim=2, index=oldest_ind, value=torch.iinfo(torch.int32).max) + self.saved_page_index.scatter_(dim=2, index=oldest_ind, src=pages_ind) + self.saved_pages_continuous = self.saved_pages[:, :, :self.saved_page_num].sort(dim=-1).values + self.saved_page_num -= 1 + + + def update_timestamp(self): + topk = self.topk_dindices_buffer + num_layer, batch_size, K = topk.shape + layer_indices = torch.arange(num_layer, device=self.device).unsqueeze(1).unsqueeze(2) + batch_indices = torch.arange(batch_size, device=self.device).unsqueeze(0).unsqueeze(2).expand(num_layer, -1, K) + current_values = self.timestamps[layer_indices, batch_indices, topk] + target_value = torch.full_like(current_values, self.kv_cache.seqlen, device=self.device) + updated_values = torch.max(current_values, target_value) + self.timestamps[layer_indices, batch_indices, topk] = updated_values + + def get_saved_pages(self, layer_idx: int): + return self.saved_pages[layer_idx, :, :] + + # Prepare metadata used for inference under certain PAGE_BUDGET + # Called multiple times for layer sensitivity + def begin_forward(self, seq_len: int, updateTensor: bool = True): + # Allocate tensor in advance + # This is used for append kernels, which need original indices + if updateTensor: + 
self.kv_indptr_for_append = torch.tensor([0, len(self.kv_cache.indicies)], dtype=torch.int32, device=self.device) + self.metadata_indptr_for_append = torch.tensor([0, len(self.metadata_cache.indicies)], dtype=torch.int32, device=self.device) + self.kv_last_page_idx = self.kv_cache.indicies[-1] + self.metadata_last_page_idx = self.metadata_cache.indicies[-1] + + if seq_len > 1: + # prefill requests + # append_kv_cache_prefill and prefill_with_paged_kv_cache + if updateTensor: + self.kv_indices_with_last = torch.tensor(self.kv_cache.indicies, dtype=torch.int32, device=self.device) + self.metadata_indices = torch.tensor(self.metadata_cache.indicies, dtype=torch.int32, device=self.device) + else: + # decode requests + # append_kv_cache_decode, estimate_attn_score, topk_filtering + cur_page_nums = len(self.kv_cache.indicies) + assert cur_page_nums > 1 # at least two pages for excluding last page + + if updateTensor: + # used for appending + self.kv_indices_with_last = torch.tensor(self.kv_cache.indicies, dtype=torch.int32, device=self.device) + + # Only used for top-k filtering (because we manully exclude the last page) as input index + self.kv_indices_without_last = torch.tensor(self.kv_cache.indicies[:-1], dtype=torch.int32, device=self.device).repeat(self.num_heads, 1) + + # used for estimate + self.metadata_indices = torch.tensor(self.metadata_cache.indicies, dtype=torch.int32, device=self.device) + + # used as page_budget for topk and approx kernel + self.inference_page_budget = min(self._page_budget, cur_page_nums) + + # Exclude the last page for decoding + self.kv_indptr_for_approx_decode = torch.tensor([0, self.inference_page_budget - 1], dtype=torch.int32, device=self.device) + + # Allocate buffer for top-k filtering + page_budet = min(self.inference_page_budget - 1, self.origin_page_budget // 2) + self.topk_dout_buffer = torch.zeros((self.num_heads, page_budet), dtype=self.dtype, device=self.device) + self.topk_dindices_buffer = torch.zeros((self.num_layers, self.num_heads, page_budet), dtype=torch.int32, device=self.device) + self.topk_buf = torch.zeros((self.num_heads, 8192 * 2 * (2+4) // 2 // 48), dtype=self.dtype, device=self.device) + + self._decode_handler.begin_forward( + self.kv_indptr_for_approx_decode, + self.num_heads, + self.num_heads, + self.head_dim, + self.page_size, + self.dtype + ) + + # Used for releasing resources + # Free memory in CUDA side + # called multiple times for layer sensitivity + def end_forward(self): + self._decode_handler.end_forward() + + # def need_estimate(self) -> bool: + # if self.inference_page_budget is None: + # return False + + # cur_page_nums = len(self.kv_cache.indicies) + # return cur_page_nums > self.inference_page_budget + + def clean_states(self): + # breakpoint() + # used by raas: pages that are saved + self.saved_page_num = 0 + self.saved_page_index = torch.arange(self.max_kv_pages_num, dtype=torch.int64, device=self.device).repeat(self.num_layers, self.num_heads, 1) + self.saved_pages = torch.arange(self.max_kv_pages_num, dtype=torch.int32, device=self.device).repeat(self.num_layers, self.num_heads, 1) + self.saved_pages_continuous = self.saved_pages[:, :, :self.saved_page_num].sort(dim=-1).values + # used by raas: the time stamps + self.timestamps = torch.zeros((self.num_layers, self.num_heads, self.max_kv_pages_num), dtype=torch.int32, device=self.device) + # breakpoint() + self.kv_cache.release() + self.metadata_cache.release() diff --git a/quest/utils/decode_wrapper.py b/quest/raas_utils/decode_wrapper.py similarity index 96% 
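
The policy implemented by `prepare_metadata` and `update_timestamp` above is essentially least-recently-used eviction at page granularity: each page carries the sequence position at which it was last written or selected, the page with the smallest stamp is evicted once the budget is exceeded, and the stamps of the pages chosen by the top-k filter are raised every step. A deliberately simplified single-head illustration, dropping the per-layer and per-head dimensions and the `saved_pages` index remapping:

```
import torch

budget = 4
timestamps = torch.tensor([5, 9, 2, 7], dtype=torch.int32)  # last step each page was used
step = 10

# A new page arrives while the budget is full: reuse the least-recently-used slot.
victim = torch.argmin(timestamps)   # page 2 (stamp 2) is evicted
timestamps[victim] = step           # the freed slot now holds the new page

# Pages picked by this step's top-k are marked as recently used (cf. update_timestamp).
selected = torch.tensor([0, 3])
timestamps[selected] = torch.maximum(
    timestamps[selected],
    torch.full_like(timestamps[selected], step),
)
```
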
rename from quest/utils/decode_wrapper.py rename to quest/raas_utils/decode_wrapper.py index 4cec8e7..c147d08 100644 --- a/quest/utils/decode_wrapper.py +++ b/quest/raas_utils/decode_wrapper.py @@ -1,8 +1,8 @@ import torch from typing import Optional -import quest._kernels as _kernels -from quest.utils.utils import TensorLayout +import quest._raas_kernels as _kernels +from quest.raas_utils.utils import TensorLayout def _check_kv_layout(kv_layout: str): if not hasattr(TensorLayout, kv_layout): diff --git a/quest/raas_utils/kv_cache.py b/quest/raas_utils/kv_cache.py new file mode 100644 index 0000000..114cc05 --- /dev/null +++ b/quest/raas_utils/kv_cache.py @@ -0,0 +1,135 @@ +# This file is modified from Punica Project +# Check ref: https://github.com/punica-ai/punica + +from quest.raas_utils.utils import TensorLayout +import torch + +class KvPool: + + def __init__( + self, + num_layers: int, + num_heads: int, + head_dim: int, + capacity: int, + block_len: int, + dtype: torch.dtype, + device: torch.device, + ): + self._layout = TensorLayout.NHD + self._buf = torch.empty( + (num_layers, capacity, 2, block_len, num_heads, head_dim), + dtype=dtype, + device=device) + + # 32 layers are identical + self._free = set(range(capacity)) + + @property + def layout(self): + return self._layout + + @property + def buf(self): + return self._buf + + @property + def num_layers(self): + l, c, _, p, n, d = self._buf.shape + return l + + @property + def block_len(self): + l, c, _, p, n, d = self._buf.shape + return p + + @property + def num_free_blocks(self): + return len(self._free) + + @property + def capacity(self): + l, c, _, p, n, d = self._buf.shape + return c + + def alloc_block(self) -> int: + idx = self._free.pop() + return idx + + def free_block(self, idx: int): + assert 0 <= idx < self.capacity + assert idx not in self._free + self._free.add(idx) + def arrange(self): + self._free = set(range(self.capacity)) + +class KvCache: + """Key-value cache for one sequence.""" + + def __init__( + self, + num_layers, + num_heads, + head_dim, + max_seq_len: int, + page_size, + dtype: torch.dtype, + device: torch.device + ): + + if max_seq_len <= 0: + raise ValueError("init_len must be non-negative") + + self._pool = KvPool( + num_layers=num_layers, + num_heads=num_heads, + head_dim=head_dim, + capacity=(max_seq_len + page_size - 1) // page_size, + block_len=page_size, + dtype=dtype, + device=device + ) + + self._indicies = [] + self._seqlen = 0 + + @property + def pool(self) -> KvPool: + return self._pool + + @property + def seqlen(self) -> int: + return self._seqlen + + @property + def last_page_len(self) -> int: + return (self.seqlen - 1) % self._pool.block_len + 1 + + @property + def indicies(self) -> list[int]: + return self._indicies + + def buf_layer(self, layer_idx: int): + assert layer_idx < self.pool.num_layers + return self._pool.buf[layer_idx] + + def append_seq(self, seq_len: int) -> int: + """Reserve space for tokens and return number of new pages""" + if seq_len <= 0: + return 0 + appended_page_count = 0 + for _ in range(seq_len): + last_page_offset = self.last_page_len + if last_page_offset == self._pool.block_len: + self._indicies.append(self._pool.alloc_block()) + appended_page_count += 1 + self._seqlen += 1 + return appended_page_count + + def release(self): + """Release all blocks""" + self._seqlen = 0 + for idx in self._indicies: + self._pool.free_block(idx) + self._pool.arrange() + self._indicies.clear() \ No newline at end of file diff --git a/quest/raas_utils/utils.py 
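
The paging arithmetic in `KvCache.append_seq` and `last_page_len` is easy to check in isolation: with `page_size=16`, appending 40 tokens should allocate three pages and leave eight tokens on the last one. A small sketch (the pool is just a preallocated tensor, so CPU suffices for the cache itself, though importing the subpackage assumes the compiled `quest._raas_kernels` extension is available):

```
import torch
from quest.raas_utils.kv_cache import KvCache

cache = KvCache(num_layers=1, num_heads=8, head_dim=64,
                max_seq_len=128, page_size=16,
                dtype=torch.float16, device="cpu")

new_pages = cache.append_seq(40)
# ceil(40 / 16) = 3 pages allocated; 40 - 2 * 16 = 8 tokens occupy the last page
print(new_pages, len(cache.indicies), cache.last_page_len)   # 3 3 8
```
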
b/quest/raas_utils/utils.py new file mode 100644 index 0000000..4f88983 --- /dev/null +++ b/quest/raas_utils/utils.py @@ -0,0 +1,5 @@ +class TensorLayout: + NHD = 0 + HND = 1 + + FORMAT2STR = {0: "NHD", 1: "HND"} \ No newline at end of file diff --git a/quest/tests/test_approx_attention.py b/quest/tests/test_approx_attention.py index ba10ba8..70674c9 100644 --- a/quest/tests/test_approx_attention.py +++ b/quest/tests/test_approx_attention.py @@ -5,7 +5,7 @@ import torch.nn as nn import math -import quest.utils +import quest.quest_utils def assert_close(a, b): rtol, atol = { @@ -136,7 +136,7 @@ def test_approx_attention_correctness(dtype_str, qo_len, kv_len, page_budget): k_prefill = torch.randn(kv_len-1, num_heads, head_dim, dtype=dtype, device=device) v_prefill = torch.randn(kv_len-1, num_heads, head_dim, dtype=dtype, device=device) - testController = quest.utils.InferenceController( + testController = quest.quest_utils.InferenceController( num_layers, num_heads, head_dim, @@ -151,7 +151,7 @@ def test_approx_attention_correctness(dtype_str, qo_len, kv_len, page_budget): testController.prepare_metadata(kv_len-1) testController.begin_forward(kv_len-1) # Construct KV - quest.utils.append_kv(k_prefill, v_prefill, testController, 0) + quest.quest_utils.append_kv(k_prefill, v_prefill, testController, 0) testController.end_forward() k_decode = torch.randn(1, num_heads, head_dim, dtype=dtype, device=device) @@ -165,10 +165,10 @@ def test_approx_attention_correctness(dtype_str, qo_len, kv_len, page_budget): # CUDA Evaluation testController.prepare_metadata(qo_len) testController.begin_forward(qo_len) - quest.utils.append_kv(k_decode, v_decode, testController, 0) + quest.quest_utils.append_kv(k_decode, v_decode, testController, 0) if testController.need_estimate() == False: - o_device = quest.utils.decode_sparse_attn( + o_device = quest.quest_utils.decode_sparse_attn( q, testController, 0, @@ -188,7 +188,7 @@ def test_approx_attention_correctness(dtype_str, qo_len, kv_len, page_budget): # estimated_attn_score, # testController, # ) - o_device = quest.utils.decode_sparse_attn( + o_device = quest.quest_utils.decode_sparse_attn( q, testController, 0, diff --git a/quest/tests/test_decode_attention.py b/quest/tests/test_decode_attention.py index c8b3681..bd494ca 100644 --- a/quest/tests/test_decode_attention.py +++ b/quest/tests/test_decode_attention.py @@ -5,7 +5,7 @@ import torch.nn as nn import math -import quest.utils +import quest.quest_utils def assert_close(a, b): rtol, atol = { @@ -69,7 +69,7 @@ def test_decode_attention_correctness(dtype_str, qo_len, kv_len): k_prefill = torch.randn(kv_len-1, num_heads, head_dim, dtype=dtype, device=device) v_prefill = torch.randn(kv_len-1, num_heads, head_dim, dtype=dtype, device=device) - testController = quest.utils.InferenceController( + testController = quest.quest_utils.InferenceController( num_layers, num_heads, head_dim, @@ -84,7 +84,7 @@ def test_decode_attention_correctness(dtype_str, qo_len, kv_len): testController.prepare_metadata(kv_len-1) testController.begin_forward(kv_len-1) # Construct KV - quest.utils.append_kv(k_prefill, v_prefill, testController, 0) + quest.quest_utils.append_kv(k_prefill, v_prefill, testController, 0) testController.end_forward() k_decode = torch.randn(1, num_heads, head_dim, dtype=dtype, device=device) @@ -92,10 +92,10 @@ def test_decode_attention_correctness(dtype_str, qo_len, kv_len): # Real decoding starts testController.prepare_metadata(1) testController.begin_forward(1) - quest.utils.append_kv(k_decode, v_decode, 
testController, 0) + quest.quest_utils.append_kv(k_decode, v_decode, testController, 0) # No CPU test cases assert testController.need_estimate() == False - o_device = quest.utils.decode_sparse_attn( + o_device = quest.quest_utils.decode_sparse_attn( q, testController, 0, diff --git a/quest/tests/test_estimate.py b/quest/tests/test_estimate.py index 0a95fd0..1f0b29e 100644 --- a/quest/tests/test_estimate.py +++ b/quest/tests/test_estimate.py @@ -5,7 +5,7 @@ import torch.nn as nn import math -import quest.utils +import quest.quest_utils def assert_close(a, b): rtol, atol = { @@ -103,7 +103,7 @@ def test_estimate_correctness(dtype_str, kv_len): k_prefill = torch.randn(kv_len-1, num_heads, head_dim, dtype=dtype, device=device) v_prefill = torch.randn(kv_len-1, num_heads, head_dim, dtype=dtype, device=device) - testController = quest.utils.InferenceController( + testController = quest.quest_utils.InferenceController( num_layers, num_heads, head_dim, @@ -118,7 +118,7 @@ def test_estimate_correctness(dtype_str, kv_len): testController.prepare_metadata(kv_len-1) testController.begin_forward(kv_len-1) # Construct KV - quest.utils.append_kv(k_prefill, v_prefill, testController, 0) + quest.quest_utils.append_kv(k_prefill, v_prefill, testController, 0) testController.end_forward() k_decode = torch.randn(1, num_heads, head_dim, dtype=dtype, device=device) @@ -127,8 +127,8 @@ def test_estimate_correctness(dtype_str, kv_len): # CUDA Evaluation testController.prepare_metadata(qo_len) testController.begin_forward(qo_len) - quest.utils.append_kv(k_decode, v_decode, testController, 0) - cuda_estimated_value = quest.utils.decode_estimate( + quest.quest_utils.append_kv(k_decode, v_decode, testController, 0) + cuda_estimated_value = quest.quest_utils.decode_estimate( q, testController, 0, diff --git a/quest/tests/test_prefill_attention.py b/quest/tests/test_prefill_attention.py index 24f6224..5614826 100644 --- a/quest/tests/test_prefill_attention.py +++ b/quest/tests/test_prefill_attention.py @@ -5,7 +5,7 @@ import torch.nn as nn import math -import quest.utils +import quest.quest_utils def assert_close(a, b): rtol, atol = { @@ -65,7 +65,7 @@ def test_prefill_attention_correctness(dtype_str, qo_len, kv_len): k = torch.randn(kv_len, num_heads, head_dim, dtype=dtype, device=device) v = torch.randn(kv_len, num_heads, head_dim, dtype=dtype, device=device) - testController = quest.utils.InferenceController( + testController = quest.quest_utils.InferenceController( num_layers, num_heads, head_dim, @@ -80,8 +80,8 @@ def test_prefill_attention_correctness(dtype_str, qo_len, kv_len): testController.prepare_metadata(kv_len) testController.begin_forward(kv_len) # Construct KV with maintained metadata - quest.utils.append_kv(k, v, testController, 0) - o_device = quest.utils.prefill_forward(q, testController, 0) + quest.quest_utils.append_kv(k, v, testController, 0) + o_device = quest.quest_utils.prefill_forward(q, testController, 0) o_host = _ref_self_attention(q, k, v) assert_close(o_device, o_host) \ No newline at end of file diff --git a/quest/tests/test_rope.py b/quest/tests/test_rope.py index 7cda3e1..707a3e6 100644 --- a/quest/tests/test_rope.py +++ b/quest/tests/test_rope.py @@ -4,7 +4,7 @@ import torch.nn as nn from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb -import quest.utils +import quest.quest_utils def assert_close(a, b): rtol, atol = { @@ -44,7 +44,7 @@ def test_apply_qk_rope(dtype_str, past_kv_len, seq_len): k = torch.randn(seq_len, num_heads, head_dim, 
dtype=dtype, device=device) q_ref, k_ref = _ref_apply_qk_rope(q, k, past_kv_len) - quest.utils.apply_rope_in_place(q, k, past_kv_len) + quest.quest_utils.apply_rope_in_place(q, k, past_kv_len) assert_close(q, q_ref) assert_close(k, k_ref) \ No newline at end of file diff --git a/quest/tests/test_topk.py b/quest/tests/test_topk.py index 54a6f78..4d39718 100644 --- a/quest/tests/test_topk.py +++ b/quest/tests/test_topk.py @@ -5,7 +5,7 @@ import torch.nn as nn import math -import quest.utils +import quest.quest_utils # This file is used for testing topk kernel from libRAFT # We do not seriously compare the topk indices since the random value leads to similar tensor. @@ -50,7 +50,7 @@ def test_topk_correctness(dtype_str, kv_len, k_budget): cuda_output_indices = torch.arange(0, k_budget, dtype=torch.int32, device=device).repeat(num_heads, 1) topk_buf = torch.zeros((num_heads, 8192 * 2 * (2+4) // 2 // 48), dtype=dtype, device=device) - quest.utils._kernels.topk_filtering( + quest.quest_utils._quest_kernels.topk_filtering( cuda_input_data, cuda_input_indices, cuda_output_data, diff --git a/quest/utils/__init__.py b/quest/utils/__init__.py index d42e80e..e69de29 100644 --- a/quest/utils/__init__.py +++ b/quest/utils/__init__.py @@ -1,276 +0,0 @@ -# import torch -# import math -# from typing import Optional - -# import quest._kernels as _kernels -# from quest.utils.utils import TensorLayout -# from quest.utils.kv_cache import KvCache -# from quest.utils.controller import InferenceController -# from quest.utils.decode_wrapper import BatchDecodeWithPagedKVCacheWrapper - -# __all__ = [ -# 'TensorLayout', -# 'KvCache', -# 'InferenceController', -# "BatchDecodeWithPagedKVCacheWrapper", -# "append_kv", -# "prefill_forward", -# "decode_estimate", -# "decode_topk", -# "decode_sparse_attn", -# "rms_norm_forward", -# "apply_rope_in_place", -# ] - -# def apply_rope_in_place( -# q: torch.Tensor, -# k: torch.Tensor, -# past_kv_len: int, -# rope_scale: Optional[float] = None, -# rope_theta: Optional[float] = None, -# ): -# """ -# Semantics of `apply_rope_in_place`: -# Apply RoPE (Relative Positional Encoding) in-place. -# On q, k which is generated by GEMM. Layout is naturally NHD. - -# Args: -# q: Shape: `[N, H, D]`. -# k: Shape: `[N, H, D]`. -# past_kv_len: Length of past KV cache. Used to calculate frequency. -# """ -# if rope_scale is None: -# rope_scale = 1.0 -# if rope_theta is None: -# rope_theta = 1e4 -# _kernels.apply_rope_in_place( -# q, -# k, -# past_kv_len, -# rope_scale, -# rope_theta, -# ) - -# def rms_norm_forward( -# input: torch.Tensor, -# weight: torch.Tensor, -# epsilon: float, -# ) -> torch.Tensor: -# o = torch.empty_like(input, dtype=input.dtype, device=input.device) -# f = _kernels.rms_norm_forward -# f( -# input, -# weight, -# o, -# epsilon, -# ) -# return o - -# def append_kv( -# k: torch.Tensor, -# v: torch.Tensor, -# iController: InferenceController, -# layer_idx: int, -# ): -# """ -# Semantics of `append_kv`: -# Append new generated k/v into kv cache and meta data cache. -# Automatically dispatch to Prefill / Decode Kernel - -# Notations for shapes: -# `B`: batch size -# `N`: number of heads -# `D`: head dimension -# `L`: number of layers -# `MAXLEN`: maximum length of the KV cache - -# Args: -# k: Shape: `[B, N, D]`. Key projection (`X @ W_k`). -# v: Shape: `[B, N, D]`. Value projection (`X @ W_v`). -# iController: InferenceController object, which contains all needed information. -# layer_idx: Layer index of the KV cache. 
-# """ -# seq_len = k.size(0) -# if seq_len > 1: -# _kernels.append_kv_cache_prefill( -# k, -# v, -# iController.kv_cache.buf_layer(layer_idx), -# iController.kv_indices_with_last, -# iController.kv_indptr_for_append, -# iController.kv_cache.last_page_len, -# iController.kv_last_page_idx, -# iController.metadata_cache.buf_layer(layer_idx), -# iController.metadata_indices, -# iController.metadata_indptr_for_append, -# iController.metadata_cache.last_page_len, -# iController.metadata_last_page_idx, -# iController.layout -# ) -# else: -# _kernels.append_kv_cache_decode( -# k, -# v, -# iController.kv_cache.buf_layer(layer_idx), -# iController.kv_indices_with_last, -# iController.kv_indptr_for_append, -# iController.kv_cache.last_page_len, -# iController.kv_last_page_idx, -# iController.metadata_cache.buf_layer(layer_idx), -# iController.metadata_indices, -# iController.metadata_indptr_for_append, -# iController.metadata_cache.last_page_len, -# iController.metadata_last_page_idx, -# iController.layout -# ) - -# def prefill_forward( -# q: torch.Tensor, -# iController: InferenceController, -# layer_idx: int, -# rope_scale: Optional[float] = None, -# rope_theta: Optional[float] = None, -# ) -> torch.Tensor: -# """ -# Semantics of `prefill_forward`: -# New genrated K/Vs are already in the kv cache and meta data cache (well-maintained). -# Perform FlashInfer Self-Attention with Casual Attention. -# Note that we not have position shift and current version not support Prefill Optimization. - -# Notations for shapes: -# `B`: batch size -# `N`: number of heads -# `D`: head dimension -# `L`: number of layers -# `MAXLEN`: maximum length of the KV cache - -# Args: -# q: Shape: `[B, N, D]`. Key projection (`X @ W_k`). -# iController: InferenceController object, which contains all needed information. -# layer_idx: Layer index of the KV cache. -# """ -# if rope_scale is None: -# rope_scale = 1.0 -# if rope_theta is None: -# rope_theta = 1e4 - -# f = _kernels.prefill_with_paged_kv_cache -# o = f( -# q, -# iController.kv_cache.buf_layer(layer_idx), -# iController.kv_indices_with_last, -# iController.kv_cache.last_page_len, -# True, # Casual -# iController.layout, -# False, # FP16 Accumulator for 4090 -# rope_scale, -# rope_theta, -# ) -# return o - -# def decode_estimate( -# q: torch.Tensor, -# iController: InferenceController, -# layer_idx: int, -# ) -> torch.Tensor: -# """ -# Semantics of `decode_estimate`: -# When decoding, estimate the attention score for each page. - -# Notations for shapes: -# `B`: batch size -# `N`: number of heads -# `D`: head dimension -# `L`: number of layers -# `MAXLEN`: maximum length of the KV cache - -# Args: -# q: Shape: `[B, N, D]`. Key projection (`X @ W_k`). -# iController: InferenceController object, which contains all needed information. -# layer_idx: Layer index of the KV cache. -# """ -# f = _kernels.estimate_attn_score -# # (iController.metadata_cache.seqlen - 1) is manually excluding the last elements, which is the current page. 
-# o = torch.empty((iController.num_heads, iController.metadata_cache.seqlen - 1), dtype=q.dtype, device=q.device) -# f( -# q, -# o, -# iController.metadata_cache.buf_layer(layer_idx), -# iController.metadata_indices, -# iController.metadata_indptr_for_append, -# iController.metadata_cache.last_page_len, # One entry delta is considered by kernel-level implementation -# iController.metadata_last_page_idx, -# iController.layout, -# ) -# return o - -# def decode_topk( -# estimated_attn_score: torch.Tensor, -# iController: InferenceController, -# ): -# """ -# Semantics of `decode_topk`: -# select top-k pages with highest attention score. - -# Notations for shapes: -# `B`: batch size -# `N`: number of heads -# `D`: head dimension -# `L`: number of layers -# `MAXLEN`: maximum length of the KV cache - -# Args: -# q: Shape: `[B, N, D]`. Key projection (`X @ W_k`). -# iController: InferenceController object, which contains all needed information. -# layer_idx: Layer index of the KV cache. -# """ -# # excluding the last page -# page_budet = iController.inference_page_budget - 1 -# f = _kernels.topk_filtering -# f( -# estimated_attn_score, -# iController.kv_indices_without_last, -# iController.topk_dout_buffer, -# iController.topk_dindices_buffer, -# iController.topk_buf, -# page_budet, -# ) - -# def decode_sparse_attn( -# q: torch.Tensor, -# iController: InferenceController, -# layer_idx: int, -# topk_indices: torch.Tensor, -# rope_scale: Optional[float] = None, -# rope_theta: Optional[float] = None, -# ) -> torch.Tensor: -# """ -# Semantics of `decode_sparse_attn`: -# Excute self-attention only on the selected pages (Top-k output) - -# Notations for shapes: -# `B`: batch size -# `N`: number of heads -# `D`: head dimension -# `L`: number of layers -# `MAXLEN`: maximum length of the KV cache - -# Args: -# q: Shape: `[B, N, D]`. Key projection (`X @ W_k`). -# iController: InferenceController object, which contains all needed information. -# layer_idx: Layer index of the KV cache. -# topk_indices: Shape: `[N, page_budget-1]`. Top-k indices. -# """ -# o = torch.empty_like(q, dtype=q.dtype, device=q.device) -# iController._decode_handler.forward( -# q, -# o, -# iController.kv_cache.buf_layer(layer_idx), -# topk_indices, -# iController.kv_indptr_for_approx_decode, -# iController.kv_cache.last_page_len, -# iController.kv_last_page_idx, -# rope_scale, -# rope_theta, -# ) -# return o \ No newline at end of file diff --git a/quest/utils/cache_utils.py.bak b/quest/utils/cache_utils.py.bak deleted file mode 100644 index d6b3686..0000000 --- a/quest/utils/cache_utils.py.bak +++ /dev/null @@ -1,383 +0,0 @@ -import copy -import importlib.metadata -import json -import os -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch - -from transformers.utils import logging -from transformers.utils.deprecation import deprecate_kwarg - - -logger = logging.get_logger(__name__) - - -class Cache(torch.nn.Module): - """ - Base, abstract class for all caches. The actual data structure is specific to each subclass. - """ - - def __init__(self): - super().__init__() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. 
- value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. These are specific to each subclass and allow new types of - cache to be created. - - Return: - A tuple containing the updated key and value states. - """ - raise NotImplementedError("Make sure to implement `update` in a subclass.") - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") - - def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states, if there is any.""" - raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.") - - def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: - """Given the sequence length of the new inputs, returns the usable length of the cache.""" - # Cache without size limit -> all cache is usable - # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache - # length, we will need to evict part of the cache (and thus not all cache is usable) - max_length = self.get_max_length() - previous_seq_length = self.get_seq_length(layer_idx) - if max_length is not None and previous_seq_length + new_seq_length > max_length: - return max_length - new_seq_length - return previous_seq_length - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx] != []: - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - if self.value_cache[layer_idx] != []: - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - def reset_cache(): - """ - Reset the cache to its initial state. - """ - raise NotImplementedError("Make sure to implement `reset_cache` in a subclass.") - - - @property - def seen_tokens(self): - logger.warning_once( - "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " - "model input instead." - ) - if hasattr(self, "_seen_tokens"): - return self._seen_tokens - else: - return None - - -@dataclass -class CacheConfig: - """ - Base class for cache configs - """ - - cache_implementation: None - - @classmethod - def from_dict(cls, config_dict, **kwargs): - """ - Constructs a CacheConfig instance from a dictionary of parameters. - Args: - config_dict (Dict[str, Any]): Dictionary containing configuration parameters. - **kwargs: Additional keyword arguments to override dictionary values. - - Returns: - CacheConfig: Instance of CacheConfig constructed from the dictionary. 
- """ - config = cls(**config_dict) - to_remove = [] - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) - return config - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file - def to_json_file(self, json_file_path: Union[str, os.PathLike]): - """ - Save this instance to a JSON file. - - Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file in which this configuration instance's parameters will be saved. - use_diff (`bool`, *optional*, defaults to `True`): - If set to `True`, only the difference between the config instance and the default - `QuantizationConfig()` is serialized to JSON file. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - config_dict = self.to_dict() - json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - - writer.write(json_string) - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict - def to_dict(self) -> Dict[str, Any]: - """ - Serializes this instance to a Python dictionary. Returns: - `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. - """ - return copy.deepcopy(self.__dict__) - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ - def __iter__(self): - """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" - for attr, value in copy.deepcopy(self.__dict__).items(): - yield attr, value - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - - def to_json_string(self): - """ - Serializes this instance to a JSON formatted string. - Returns: - str: JSON formatted string representing the configuration instance. - """ - return json.dumps(self.__dict__, indent=2) + "\n" - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update - def update(self, **kwargs): - """ - Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, - returning all the unused kwargs. - - Args: - kwargs (`Dict[str, Any]`): - Dictionary of attributes to tentatively update this class. - - Returns: - `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. - """ - to_remove = [] - for key, value in kwargs.items(): - if hasattr(self, key): - setattr(self, key, value) - to_remove.append(key) - - # Remove all the attributes that were updated, without modifying the input dict - unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} - return unused_kwargs - - -class SinkCache(Cache): - """ - A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to - generate beyond the length of its context window, without losing fluency in the conversation. As it discards past - tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. - - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. - - Parameters: - window_length (`int`): - The length of the context window. - num_sink_tokens (`int`): - The number of sink tokens. See the original paper for more information. 
- - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - SinkCache() - ``` - """ - - def __init__(self, window_length: int, num_sink_tokens: int) -> None: - super().__init__() - self.key_cache: List[torch.Tensor] = [] - self.value_cache: List[torch.Tensor] = [] - self.window_length = window_length - self.num_sink_tokens = num_sink_tokens - self.cos_sin_rerotation_cache = {} - self._cos_cache = None - self._sin_cache = None - self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen - - @staticmethod - def _rotate_half(x): - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def _apply_key_rotary_pos_emb( - self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor - ) -> torch.Tensor: - rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) - return rotated_key_states - - def _get_rerotation_cos_sin( - self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - if key_states.shape[-2] not in self.cos_sin_rerotation_cache: - # Upcast to float32 temporarily for better accuracy - cos = cos.to(torch.float32) - sin = sin.to(torch.float32) - - # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence - original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] - shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] - original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] - shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] - rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin - rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin - - self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( - rerotation_cos.to(key_states.dtype).unsqueeze(0), - rerotation_sin.to(key_states.dtype).unsqueeze(0), - ) - return self.cos_sin_rerotation_cache[key_states.shape[-2]] - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length - if len(self.key_cache) <= layer_idx: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states.""" - return self.window_length - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. 
- value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, - `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the - rotation as the tokens are shifted. - - Return: - A tuple containing the updated key and value states. - """ - # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models - # with partially rotated position embeddings, like Phi or Persimmon. - sin = cache_kwargs.get("sin") - cos = cache_kwargs.get("cos") - partial_rotation_size = cache_kwargs.get("partial_rotation_size") - using_rope = cos is not None and sin is not None - - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the sin/cos cache, which holds sin/cos values for all possible positions - if using_rope and layer_idx == 0: - # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove - # after all RoPE models have a llama-like cache utilization. - if cos.dim() == 2: - self._cos_cache = cos - self._sin_cache = sin - else: - if self._cos_cache is None: - self._cos_cache = cos[0, ...] - self._sin_cache = sin[0, ...] - elif self._cos_cache.shape[0] < self.window_length: - self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) - self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) - - # [bsz, num_heads, seq_len, head_dim] - if len(self.key_cache) <= layer_idx: - # Empty cache - self.key_cache.append(key_states) - self.value_cache.append(value_states) - - elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: - # Growing cache - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - else: - # Shifting cache - keys_to_keep = self.key_cache[layer_idx][ - :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : - ] - - # On RoPE models, we need to recompute the Key rotation as the tokens are shifted - if using_rope: - rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( - key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length] - ) - if partial_rotation_size is not None: - keys_to_keep, keys_pass = ( - keys_to_keep[..., :partial_rotation_size], - keys_to_keep[..., partial_rotation_size:], - ) - keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin) - if partial_rotation_size is not None: - keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) - - # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens - sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] - self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) - - sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] - values_to_keep = self.value_cache[layer_idx][ - :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : - ] - self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - - def reset_cache(self): - self.__init__(self.window_length, 
self.num_sink_tokens) \ No newline at end of file
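
For reference, the re-rotation coefficients computed in `_get_rerotation_cos_sin` of the removed `SinkCache` copy above are just the angle-difference identities: a key is rotated from its old position angle `a` to its new, earlier angle `b` using `cos(b - a)` and `sin(b - a)`. A quick scalar check, with plain angles standing in for the per-dimension RoPE frequencies:

```
import torch

a, b = torch.tensor(1.3), torch.tensor(0.4)   # old and new position angles
rerot_cos = torch.cos(a) * torch.cos(b) + torch.sin(a) * torch.sin(b)
rerot_sin = -torch.sin(a) * torch.cos(b) + torch.cos(a) * torch.sin(b)
assert torch.allclose(rerot_cos, torch.cos(b - a))
assert torch.allclose(rerot_sin, torch.sin(b - a))
```
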