diff --git a/modules/custom_operations/tests/requirements.txt b/modules/custom_operations/tests/requirements.txt
index f115e7945..f64b33e51 100644
--- a/modules/custom_operations/tests/requirements.txt
+++ b/modules/custom_operations/tests/requirements.txt
@@ -1,5 +1,6 @@
 torch
 onnx
+onnxscript
 tensorboard
 pytest
 # open3d==0.16.0 - need to update with new release
diff --git a/modules/genai_optimizations/README.md b/modules/genai_optimizations/README.md
index b019e7278..55ff1388f 100644
--- a/modules/genai_optimizations/README.md
+++ b/modules/genai_optimizations/README.md
@@ -6,6 +6,7 @@ This module provides experimental optimizations for GenAI models in PyTorch. The
 - Text Generation Using LLMs
 - Visual language text generation
+- Reasoning and Problem Solving
 
 ## Supported Generative AI Optimization Methods
 
@@ -34,6 +35,14 @@ This module provides experimental optimizations for GenAI models in PyTorch. The
    Paper: https://arxiv.org/pdf/2306.14048
  - **SnapKV Mode** – Modifies the *H2O* approach by computing token importance within a small sliding window of the most recent queries during the prefill stage, then reverting to the H2O strategy during decoding. The authors observed that only a small subset of prompt tokens is sufficient for accurate response generation.
    Paper: https://arxiv.org/pdf/2404.14469
+ - **RKV Mode** – Computes token importance scores based on attention weights over a sliding window of the most recent queries during both the prefill and decode stages. Importance scores are stabilized with per-token max-pooling and then averaged across attention heads.
+
+Refined modes enhance standard eviction strategies by selecting the most representative tokens or blocks from the evictable (intermediate) region, balancing contextual importance with redundancy reduction to optimize cache efficiency. If `refined_algorithm` is enabled and `refined_tokens` is set to 0, the number of refined tokens is determined dynamically as part of the intermediate token budget: the primary algorithm receives the minimal number of tokens or groups that together capture at least 90% of the total attention mass, ensuring that all high-importance tokens are retained. For the remaining eviction budget, each token's dissimilarity is computed relative to the already retained set, promoting information diversity and reducing redundancy.
+
+ Supported refined modes:
+ - **KVCrush Mode** – Selects representative blocks based on diversity rather than raw importance: binary indicators are generated for each token, an anchor point (reference pattern) is constructed using one of several modes (`random`, `zeros`, `ones`, `mean`, `alternate`), and the blocks with the highest Hamming distance to the anchor point are selected.
+   Paper: https://arxiv.org/pdf/2503.00022
+ - **DiverseKV Mode** – Implements a dynamic redundancy scoring mechanism to identify and de-prioritize repetitive tokens based on the cosine similarity of their key vectors with already retained tokens. Key vectors are normalized, and cosine similarities are computed with diagonal values zeroed to avoid self-similarity. Similarities are thresholded on a per-head basis: only values greater than or equal to the mean similarity for each head are kept and then aggregated across heads. For the remaining eviction budget, each token or group's dissimilarity to the already retained tokens or groups is calculated. Tokens/groups with the highest dissimilarity scores are retained, maximizing contextual diversity while reducing redundancy.
 
 ## Supported and tested models
 
@@ -53,6 +62,12 @@ Multimodal Large Language Models:
 - [Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
 - [Qwen/Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)
 
+Large Reasoning Models:
+
+- [deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B)
+- [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B)
+- [microsoft/Phi-4-mini-reasoning](https://huggingface.co/microsoft/Phi-4-mini-reasoning)
+
 ## Prerequisites
 
 Before running algorithms, ensure you have **Python 3.10+** installed and set up your environment.
diff --git a/modules/genai_optimizations/benchmarks/README.md b/modules/genai_optimizations/benchmarks/README.md
index a2263fc3e..b3140d37a 100644
--- a/modules/genai_optimizations/benchmarks/README.md
+++ b/modules/genai_optimizations/benchmarks/README.md
@@ -10,6 +10,8 @@ This [example](./longbench.py) demonstrates how to evaluate and optimize LLMs us
 Sparse attention speeds up the prefill stage in LLMs by attending only to the most relevant query-key blocks. Static patterns like Tri-Shape and dynamic mechanisms like XAttention reduce memory and computation without significant accuracy loss, enabling efficient handling of long prompts.
 
+KV-Cache Token Eviction accelerates the decoding stage in LLMs by removing less important cached tokens while preserving those essential for contextual understanding, allowing efficient long-sequence inference under constrained memory.
+
 ### Run Example
 
 ```bash
@@ -100,3 +102,32 @@ This will automatically:
 
 - Evaluate the model and report the score
+
+
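+The token-eviction setup used by these benchmarks can also be constructed programmatically. The sketch below is illustrative: it uses the `genai_opt` API added in this PR, `model` and `inputs` are placeholder objects, and the parameter list assumes the fields declared in `token_eviction.py`:
+
+```python
+from genai_opt import (
+    KVCacheCompressionMode,
+    KVCacheCompressionParameters,
+    KVCacheCompressor,
+    KVCacheRefinedSelection,
+)
+
+# Cache budget: 32 "start" (sink) tokens, up to 512 intermediate tokens, 128 recent tokens.
+params = KVCacheCompressionParameters(
+    algorithm=KVCacheCompressionMode.SNAPKV,
+    start_tokens=32,
+    recent_tokens=128,
+    intermediate_tokens=512,
+    refined_algorithm=KVCacheRefinedSelection.KVCRUSH,  # optional refined selection
+    refined_tokens=64,
+    kvcrush_anchor="alternate",
+)
+compressor = KVCacheCompressor(eviction_parameters=params)
+
+# The compressor patches the model only while the context is active,
+# mirroring how the benchmark scripts apply it via contextlib.ExitStack.
+with compressor(model):
+    output = model.generate(**inputs, do_sample=False, max_new_tokens=1024)
+```
+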
+## Large Reasoning Models Optimization Example: MATH500 and GSM8K Benchmarks
+
+This [example](./math500_gsm_bench.py) demonstrates how to evaluate and optimize LRMs using the KV-Cache Token Eviction algorithm. The example leverages the [MATH500](https://huggingface.co/datasets/HuggingFaceH4/MATH-500) and [GSM8K](https://huggingface.co/datasets/openai/gsm8k) datasets.
+MATH500 contains a subset of 500 problems from the [MATH](https://github.com/hendrycks/math) benchmark, originally introduced in OpenAI's *Let's Verify Step by Step* paper. The subset covers six domains: algebra, geometry, intermediate algebra, number theory, precalculus, and probability.
+GSM8K (Grade School Math 8K) is a dataset of 8,500 high-quality, linguistically diverse grade-school math word problems. While the problems are conceptually simple, their multi-step reasoning requirements make them challenging for state-of-the-art language models.
+
+### Run Example
+
+```bash
+python math500_gsm_bench.py \
+    --dataset MATH500 \
+    --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
+    --max_tokens 5000 \
+    --max_examples 100 \
+    --enable_eviction \
+    --algorithm rkv \
+    --granularity per_group \
+    --intermediate_tokens 512
+```
+
+This will automatically:
+
+- Download the selected model and dataset
+- Apply token eviction during the decoding stage
+- Evaluate the model and report the score (see the grading sketch below)
+
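+The reported score comes from extracting the boxed answer from each generation and comparing it against the ground truth. A minimal sketch of that grading path, using the helpers from `reasoning_parser.py` (the generation string here is illustrative):
+
+```python
+from reasoning_parser import extract_answer, math_equal, strip_string
+
+generation = r"... so the final answer is \boxed{\frac{1}{2}}."
+pred = extract_answer(generation, data_name="omni-math")  # -> "\frac{1}{2}"
+gt = strip_string("1/2")  # normalizes to "\frac{1}{2}"
+print(math_equal(pred, gt))  # True: both forms are equivalent
+```
+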
diff --git a/modules/genai_optimizations/benchmarks/math500_gsm_bench.py b/modules/genai_optimizations/benchmarks/math500_gsm_bench.py new file mode 100644 index 000000000..f2f99d966 --- /dev/null +++ b/modules/genai_optimizations/benchmarks/math500_gsm_bench.py @@ -0,0 +1,308 @@ +# Copyright (C) 2018-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# This logic is largely copied from the +# - https://github.com/microsoft/ProphetNet/tree/master/CRITIC +# - https://github.com/openai/prm800k +# - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py +# - https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py +# - https://github.com/VITA-Group/SEAL/tree/main + +import argparse +import json +import os +import random +import re +from collections import Counter +from contextlib import ExitStack + +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer + +from utils import add_attention_args, add_token_eviction_args +from utils import get_eviction_patcher, get_sparse_attention_patcher + +from reasoning_parser import extract_answer +from reasoning_parser import parallel_math_equal +from reasoning_parser import strip_string + +# disable tokenizer parallelism warnings +os.environ["TOKENIZERS_PARALLELISM"] = "false" +OUTPUT_LENGTHS = [] + + +def run_evaluation(res_path, save=False, k=None, output_dir=None): + with open(res_path) as f: + lines = f.readlines() + data = [json.loads(line) for line in lines] + + for example in tqdm(data): + if "model_generation" not in example: + example["model_generation"] = example["model_output"] + if k is not None: + example["model_generation"] = example["model_generation"][:k] + gt_cot = example["answer"] + gt_ans = extract_answer(gt_cot, data_name="omni-math") + gt_cot = str(gt_cot).strip() + gt_ans = strip_string(gt_ans, skip_unit=False) + all_pred = [extract_answer(p, data_name="omni-math") for p in example["model_generation"]] + all_pred = [strip_string(p, skip_unit=False) for p in all_pred] + all_eval = parallel_math_equal(all_pred, gt_ans, timeout=5) + effective_pred = [p for p, o in zip(all_pred, example["model_generation"]) if "boxed" in o] + if len(effective_pred) == 0: + effective_pred = all_pred + counter = Counter(effective_pred) + pred = counter.most_common(1)[0][0] + index = all_pred.index(pred) + eval = all_eval[index] + example["all_pred"] = all_pred + example["all_eval"] = all_eval + example["mv_pred"] = pred + example["mv_eval"] = eval + example["mv_index"] = index + + acc = sum([example["mv_eval"] for example in data]) / len(data) + print(f"Accuracy: {acc:.3f}") + + correct_avg_len = [] + incorrect_avg_len = [] + + for i, example in enumerate(data): + if example["mv_eval"]: + correct_avg_len.append(OUTPUT_LENGTHS[i]) + else: + incorrect_avg_len.append(OUTPUT_LENGTHS[i]) + + if len(correct_avg_len) != 0: + print(f"Correct avg len: {sum(correct_avg_len) / len(correct_avg_len):.2f}", end=", ") + if len(incorrect_avg_len) != 0: + print(f"Incorrect avg len: {sum(incorrect_avg_len) / len(incorrect_avg_len):.2f}") + + if save: + out_file = os.path.join(output_dir, "math_eval.jsonl") + with open(out_file, "w") as f: + for example in data: + f.write(json.dumps(example) + "\n") + + metric_file = os.path.join(output_dir, "metrics.json") + with open(metric_file, "w") as f: + json.dump({"acc": acc}, f) + + +def trim_output(output): + instruction_prefix = "Answer the following question" + question_prefix = "Question:" + 
comment_prefix = "Comment:" # for some reason, Llama 13B likes to generate these comments indefinitely + + for prefix in [instruction_prefix, question_prefix, comment_prefix]: + if prefix in output: + output = output.split(prefix)[0] + + return output + + +def extract_box(pred_str): + ans = pred_str.split("boxed")[-1] + if len(ans) == 0: + return "" + elif ans[0] == "{": + stack = 1 + a = "" + for c in ans[1:]: + if c == "{": + stack += 1 + a += c + elif c == "}": + stack -= 1 + if stack == 0: + break + a += c + else: + a += c + else: + a = ans.split("$")[0].strip() + + return a + + +def prepare_dataset(dataset, max_samples=None): + test_data = [] + if dataset == "MATH500": + data = load_dataset("HuggingFaceH4/MATH-500", split="test") + for example in data: + gt = extract_box(example["solution"]) + test_data.append( + { + "question": example["problem"], + "answer": example["solution"], + "gt": gt, + } + ) + elif dataset == "GSM": + data_path = "gsm.jsonl" + + if not os.path.exists(data_path): + import requests + url = "https://raw.githubusercontent.com/VITA-Group/SEAL/main/data/gsm/test.jsonl" + response = requests.get(url) + response.raise_for_status() + with open(data_path, "w", encoding="utf-8") as f: + f.write(response.text) + print(f"Downloaded and saved to '{data_path}'.") + + with open(data_path) as fin: + for line in fin: + example = json.loads(line) + answer = example["answer"].split("####")[1].strip() + answer = re.sub(r"(\d),(\d)", r"\1\2", answer) + test_data.append( + { + "question": example["question"], + "answer": example["answer"].split("####")[0].strip(), + "gt": answer, + } + ) + + if max_samples and len(test_data) > max_samples: + test_data = test_data[:max_samples] + + return test_data + + +def main(args): + random.seed(42) + + test_data = prepare_dataset(args.dataset, max_samples=args.max_examples) + + tokenizer = AutoTokenizer.from_pretrained(args.model) + # set pad token to eos token if pad token is not set (as is the case for llama models) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + + contexts = [] + if args.use_custom_attention: + sparse_attn = get_sparse_attention_patcher(args) + contexts.append(sparse_attn) + + if args.enable_eviction: + token_eviction = get_eviction_patcher(args) + contexts.append(token_eviction) + + prefix = ( + "Answer the following questions. 
You should think step-by-step and put your final answer within \\boxed{}.\n"
+    )
+    prompts = []
+    for example in test_data:
+        prompt = prefix + "Question: " + example["question"].strip() + "\nAnswer: "
+        if not args.omit_chat_template:
+            if "deepseek" in args.model:
+                messages = [{"role": "user", "content": prefix + "Question: " + example["question"].strip()}]
+            else:
+                messages = [
+                    {"role": "system", "content": prefix},
+                    {"role": "user", "content": "Question: " + example["question"].strip()},
+                ]
+            prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            if not args.keep_bos and tokenizer.bos_token is not None and prompt.startswith(tokenizer.bos_token):
+                prompt = prompt[len(tokenizer.bos_token) :]
+        prompts.append(prompt)
+
+    kwargs = {"temperature": None, "top_p": None, "top_k": None}
+    # force attn_implementation="eager" when using token eviction without custom attention
+    if args.enable_eviction and not args.use_custom_attention:
+        kwargs["attn_implementation"] = "eager"
+
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model,
+        trust_remote_code=True,
+        device_map="auto",
+        token=os.environ.get("HF_TOKEN", None),
+        **kwargs
+    )
+    model.eval()
+
+    outputs = []
+    avg_prompt_len = []
+    # apply the sparse-attention/eviction patchers assembled above
+    with ExitStack() as stack:
+        for ctx in contexts:
+            if ctx is not None:
+                stack.enter_context(ctx(model))
+
+        for prompt in tqdm(prompts):
+            tokenized_batch = tokenizer(prompt, return_tensors="pt", padding=True)
+            tokenized_batch = {k: v.to(model.device) for k, v in tokenized_batch.items()}
+            avg_prompt_len.append(tokenized_batch["input_ids"].shape[1])
+
+            output = model.generate(
+                **tokenized_batch,
+                do_sample=False,
+                max_new_tokens=args.max_tokens,
+                use_cache=True,
+                pad_token_id=tokenizer.eos_token_id,
+            )
+            OUTPUT_LENGTHS.append(output.shape[1])
+            output = [tokenizer.decode(o[avg_prompt_len[-1]:], skip_special_tokens=True) for o in output]
+            outputs.extend(output)
+
+    outputs = [[trim_output(o)] for o in outputs]
+    print(f"Average prompt length: {sum(avg_prompt_len) / len(avg_prompt_len):.2f}")
+    print(f"Average length: {sum(OUTPUT_LENGTHS) / len(OUTPUT_LENGTHS):.2f}")
+
+    predictions = [
+        {
+            "prompt": prompt,
+            "problem": example["question"],
+            "answer": example["gt"],
+            "solution": example["answer"],
+            "model_generation": output,
+        }
+        for example, output, prompt in zip(test_data, outputs, prompts)
+    ]
+
+    with open(os.path.join(args.save_dir, "predictions.jsonl"), "w") as fout:
+        for prediction in predictions:
+            fout.write(json.dumps(prediction) + "\n")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, required=True)
+    parser.add_argument("--dataset", type=str, default="MATH500", choices=["MATH500", "GSM"])
+    parser.add_argument("--max_examples", type=int, default=None)
+    parser.add_argument("--start", type=int, default=None)
+    parser.add_argument("--save_dir", type=str, default="results")
+    parser.add_argument("--max_tokens", type=int, default=5000)
+    parser.add_argument("--omit_chat_template", action="store_true")
+    parser.add_argument("--keep_bos", action="store_true")
+
+    add_attention_args(parser)
+    add_token_eviction_args(parser)
+    args = parser.parse_args()
+
+    args.save_dir = os.path.join(args.save_dir, args.dataset)
+    if args.keep_bos:
+
args.save_dir = args.save_dir + "_keep_bos" + + if args.max_examples or args.start: + start = 0 if args.start is None else args.start + end = start + args.max_examples if args.max_examples is not None else -1 + args.save_dir = os.path.join(args.save_dir, f"{start}_{end}") + + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + + print(f"Results will be saved to {args.save_dir}") + main(args) + run_evaluation(os.path.join(args.save_dir, "predictions.jsonl"), output_dir=args.save_dir) diff --git a/modules/genai_optimizations/benchmarks/reasoning_parser.py b/modules/genai_optimizations/benchmarks/reasoning_parser.py new file mode 100644 index 000000000..b8b72d739 --- /dev/null +++ b/modules/genai_optimizations/benchmarks/reasoning_parser.py @@ -0,0 +1,885 @@ +# Copyright (C) 2018-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# This logic is largely copied from the +# - https://github.com/microsoft/ProphetNet/tree/master/CRITIC +# - https://github.com/openai/prm800k +# - https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py +# - https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py +# - https://github.com/VITA-Group/SEAL/tree/main + +import multiprocessing +import queue +import re +from math import isclose +from typing import Union + +import regex +from latex2sympy2 import latex2sympy +from sympy import N +from sympy import simplify +from sympy.parsing.latex import parse_latex +from sympy.parsing.sympy_parser import parse_expr +from word2number import w2n + + +def _fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if len(substr) > 0 and substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except Exception: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def _fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + if "sqrt" not in a: + a = int(a) + if "sqrt" not in b: + b = int(b) + assert string == f"{a}/{b}" + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except Exception: + return string + + +def _fix_sqrt(string): + _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) + return _string + + +def convert_word_number(text: str) -> str: + try: + string = str(w2n.word_to_num(text)) + return string + except ValueError: + return text + + +# units mainly from MathQA +unit_texts = [ + "east", + "degree", + "mph", + "kmph", + "ft", + "m sqaure", + " m east", + "sq m", + "deg", + "mile", + "q .", + "monkey", + "prime", + "ratio", + "profit of rs", + "rd", + "o", + "gm", + "p . m", + "lb", + "tile", + "per", + "dm", + "lt", + "gain", + "ab", + "way", + "west", + "a .", + "b .", + "c .", + "d .", + "e .", + "f .", + "g .", + "h .", + "t", + "a", + "h", + "no change", + "men", + "soldier", + "pie", + "bc", + "excess", + "st", + "inches", + "noon", + "percent", + "by", + "gal", + "kmh", + "c", + "acre", + "rise", + "a . m", + "th", + "π r 2", + "sq", + "mark", + "l", + "toy", + "coin", + "sq . 
m", + "gallon", + "° f", + "profit", + "minw", + "yr", + "women", + "feet", + "am", + "pm", + "hr", + "cu cm", + "square", + "v â € ™", + "are", + "rupee", + "rounds", + "cubic", + "cc", + "mtr", + "s", + "ohm", + "number", + "kmph", + "day", + "hour", + "minute", + "min", + "second", + "man", + "woman", + "sec", + "cube", + "mt", + "sq inch", + "mp", + "∏ cm ³", + "hectare", + "more", + "sec", + "unit", + "cu . m", + "cm 2", + "rs .", + "rs", + "kg", + "g", + "month", + "km", + "m", + "cm", + "mm", + "apple", + "liter", + "loss", + "yard", + "pure", + "year", + "increase", + "decrease", + "d", + "less", + "Surface", + "litre", + "pi sq m", + "s .", + "metre", + "meter", + "inch", +] + +unit_texts.extend([t + "s" for t in unit_texts]) + + +def strip_string(string, skip_unit=False): + string = str(string).strip() + # linebreaks + string = string.replace("\n", "") + + # right "." + string = string.rstrip(".") + + # remove inverse spaces + # replace \\ with \ + string = string.replace("\\!", "") + # string = string.replace("\\ ", "") + # string = string.replace("\\\\", "\\") + + # matrix + string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string) + string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string) + string = string.replace("bmatrix", "pmatrix") + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + string = string.replace("\\neq", "\\ne").replace("\\leq", "\\le").replace("\\geq", "\\ge") + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + string = string.replace("\\{", "{") + string = string.replace("\\}", "}") + + # Remove unit: miles, dollars if after is not none + _string = re.sub(r"\\text{.*?}$", "", string).strip() + if _string != "" and _string != string: + # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) + string = _string + + if not skip_unit: + # Remove unit: texts + for _ in range(2): + for unit_text in unit_texts: + # use regex, the prefix should be either the start of the string or a non-alphanumeric character + # the suffix should be either the end of the string or a non-alphanumeric character + _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string) + if _string != "": + string = _string + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + string = string.replace("$", "") + string = string.replace("\\(", "").replace("\\)", "") + + # convert word number to digit + string = convert_word_number(string) + + # replace "\\text{...}" to "..." + string = re.sub(r"\\text\{(.*?)\}", r"\1", string) + for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]: + string = string.replace(key, "") + string = string.replace("\\emptyset", r"{}") + string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}") + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") + string = string.replace("%", "") + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." 
is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + + # cdot + # string = string.replace("\\cdot", "") + if ( + string.startswith("{") + and string.endswith("}") + and string.isalnum() + or string.startswith("(") + and string.endswith(")") + and string.isalnum() + or string.startswith("[") + and string.endswith("]") + and string.isalnum() + ): + string = string[1:-1] + + # inf + string = string.replace("infinity", "\\infty") + if "\\infty" not in string: + string = string.replace("inf", "\\infty") + string = string.replace("\\inity", "\\infty") + + # and + string = string.replace("and", "") + string = string.replace("\\mathbf", "") + + # use regex to remove \mbox{...} + string = re.sub(r"\\mbox{.*?}", "", string) + + # quote + string.replace("'", "") + string.replace('"', "") + + # i, j + if "j" in string and "i" not in string: + string = string.replace("j", "i") + + # replace a.000b where b is not number or b is end, with ab, use regex + string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string) + string = re.sub(r"(\d+)\.0*$", r"\1", string) + + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + string = _fix_sqrt(string) + string = string.replace(" ", "") + + string = _fix_fracs(string) + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = _fix_a_slash_b(string) + + return string + + +direct_answer_trigger_for_fewshot = ("choice is", "answer is") + + +def choice_answer_clean(pred: str): + pred = pred.strip("\n") + + # Determine if this is ICL, if so, use \n\n to split the first chunk. + ICL = False + for trigger in direct_answer_trigger_for_fewshot: + if pred.count(trigger) > 1: + ICL = True + if ICL: + pred = pred.split("\n\n")[0] + + # Split the trigger to find the answer. + preds = re.split("|".join(direct_answer_trigger_for_fewshot), pred) + if len(preds) > 1: + answer_flag = True + pred = preds[-1] + else: + answer_flag = False + + pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":") + + # Clean the answer based on the dataset + tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) + if tmp: + pred = tmp + else: + pred = [pred.strip().strip(".")] + + if len(pred) == 0: + pred = "" + else: + if answer_flag: + # choose the first element in list ... + pred = pred[0] + else: + # choose the last e + pred = pred[-1] + + # Remove the period at the end, again! + pred = pred.rstrip(".").rstrip("/") + + return pred + + +def extract_answer(pred_str, data_name, use_last_number=True): + if data_name == "omni-math": + use_last_number = False + pred_str = pred_str.replace("\u043a\u0438", "") + if data_name in ["mmlu_stem", "sat_math", "aqua", "gaokao2023"]: + # TODO check multiple choice + return choice_answer_clean(pred_str) + + if "final answer is $" in pred_str and "$. I hope" in pred_str: + # minerva_math + tmp = pred_str.split("final answer is $", 1)[1] + pred = tmp.split("$. 
I hope", 1)[0].strip() + elif "boxed" in pred_str: + ans = pred_str.split("boxed")[-1] + if len(ans) == 0: + return "" + elif ans[0] == "{": + stack = 1 + a = "" + for c in ans[1:]: + if c == "{": + stack += 1 + a += c + elif c == "}": + stack -= 1 + if stack == 0: + break + a += c + else: + a += c + else: + a = ans.split("$")[0].strip() + pred = a + elif "he answer is" in pred_str: + pred = pred_str.split("he answer is")[-1].strip() + elif "final answer is" in pred_str: + pred = pred_str.split("final answer is")[-1].strip() + elif "答案是" in pred_str: + # Handle Chinese few-shot multiple choice problem answer extraction + pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip() + else: # use the last number + if use_last_number: + pattern = "-?\d*\.?\d+" + pred = re.findall(pattern, pred_str.replace(",", "")) + if len(pred) >= 1: + pred = pred[-1] + else: + pred = "" + else: + # pred = "" + pred = pred_str + + # choice answer + if data_name in ["sat_math", "aqua"] or "mmlu" in data_name: + tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) + if tmp: + pred = tmp[-1] + else: + pred = pred.strip().strip(".") + + # multiple line + # pred = pred.split("\n")[0] + pred = re.sub(r"\n\s*", "", pred) + if pred != "" and pred[0] == ":": + pred = pred[1:] + if pred != "" and pred[-1] == ".": + pred = pred[:-1] + if pred != "" and pred[-1] == "/": + pred = pred[:-1] + pred = strip_string(pred, skip_unit=data_name in ["carp_en", "minerva_math"]) + return pred + + +STRIP_EXCEPTIONS = ["carp_en", "minerva_math"] + + +def parse_ground_truth(example, data_name): + if "gt_cot" in example and "gt" in example: + if data_name in ["math"]: + gt_ans = extract_answer(example["gt_cot"], data_name) + elif data_name == "omni-math": + if "boxed" not in example["gt_cot"]: + example["gt_cot"] = "\\boxed" + "{" + example["gt_cot"] + "}" + gt_ans = extract_answer(example["gt_cot"], data_name, use_last_number=False) + elif data_name in STRIP_EXCEPTIONS: + gt_ans = example["gt"] + else: + gt_ans = strip_string(example["gt"]) + return example["gt_cot"], gt_ans + + # parse ground truth + if data_name in ["math", "minerva_math", "omni-math"]: + gt_cot = example["solution"] + gt_ans = extract_answer(gt_cot, data_name) + elif data_name == "gsm8k": + gt_cot, gt_ans = example["answer"].split("####") + elif data_name == "svamp": + gt_cot, gt_ans = example["Equation"], example["Answer"] + elif data_name == "asdiv": + gt_cot = example["formula"] + gt_ans = re.sub(r"\(.*?\)", "", example["answer"]) + elif data_name == "mawps": + gt_cot, gt_ans = None, example["target"] + elif data_name == "tabmwp": + gt_cot = example["solution"] + gt_ans = example["answer"] + if example["ans_type"] in ["integer_number", "decimal_number"]: + if "/" in gt_ans: + gt_ans = int(gt_ans.split("/")[0]) / int(gt_ans.split("/")[1]) + elif "," in gt_ans: + gt_ans = float(gt_ans.replace(",", "")) + elif "%" in gt_ans: + gt_ans = float(gt_ans.split("%")[0]) / 100 + else: + gt_ans = float(gt_ans) + elif data_name == "carp_en": + gt_cot, gt_ans = example["steps"], example["answer"] + elif data_name == "mmlu_stem": + abcd = "ABCD" + gt_cot, gt_ans = None, abcd[example["answer"]] + elif data_name == "sat_math": + gt_cot, gt_ans = None, example["Answer"] + elif data_name == "aqua": + gt_cot, gt_ans = None, example["correct"] + elif data_name in ["gaokao2023en", "college_math", "gaokao_math_cloze"]: + gt_cot, gt_ans = None, example["answer"].replace("$", "").strip() + elif data_name == "gaokao_math_qa": + gt_cot, gt_ans = None, example["label"] + elif data_name 
in ["gaokao2024_mix", "cn_middle_school"]: + if len(example["choice_answer"]) > 0: + gt_cot, gt_ans = None, example["choice_answer"] + else: + gt_cot, gt_ans = None, example["answer"] + elif data_name == "olympiadbench": + gt_cot, gt_ans = None, example["final_answer"][0].strip("$") + elif data_name in [ + "aime24", + "amc23", + "cmath", + "gaokao2024_I", + "gaokao2024_II", + "imo2024", + ]: + gt_cot, gt_ans = None, example["answer"] + else: + error_msg = f"`{data_name}`" + raise NotImplementedError(error_msg) + # post process + gt_cot = str(gt_cot).strip() + if data_name not in STRIP_EXCEPTIONS: + gt_ans = strip_string(gt_ans, skip_unit=data_name == "carp_en") + else: + gt_ans = gt_ans.replace("\\neq", "\\ne").replace("\\leq", "\\le").replace("\\geq", "\\ge") + return gt_cot, gt_ans + + +# --------------Evaluation Utils----------------- + + +def postprocess_choice_answer(pred: str): + pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":") + # Clean the answer based on the dataset + tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) + if tmp: + pred = tmp + else: + pred = [pred.strip().strip(".")] + pred = pred[-1] + # Remove the period at the end, again! + pred = pred.rstrip(".").rstrip("/") + return pred + + +def parse_digits(num): + num = regex.sub(",", "", str(num)) + try: + return float(num) + except ValueError: + if num.endswith("%"): + num = num[:-1] + if num.endswith("\\"): + num = num[:-1] + try: + return float(num) / 100 + except ValueError: + pass + return None + + +def is_digit(num): + # paired with parse_digits + return parse_digits(num) is not None + + +def str_to_pmatrix(input_str): + input_str = input_str.strip() + matrix_str = re.findall(r"\{.*,.*\}", input_str) + pmatrix_list = [] + + for m in matrix_str: + m = m.strip("{}") + pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}" + pmatrix_list.append(pmatrix) + + return ", ".join(pmatrix_list) + + +def math_equal( + prediction: Union[bool, float, str], + reference: Union[float, str], + include_percentage: bool = True, + is_close: bool = True, + timeout: bool = False, +) -> bool: + """ + Exact match of math if and only if: + 1. numerical equal: both can convert to float and are equal + 2. symbolic equal: both can convert to sympy expression and are equal + """ + # print("Judge:", prediction, reference) + if prediction is None or reference is None: + return False + if str(prediction.strip().lower()) == str(reference.strip().lower()): + return True + if reference in ["A", "B", "C", "D", "E"] and postprocess_choice_answer(prediction) == reference: + return True + + try: # 1. numerical equal + if is_digit(prediction) and is_digit(reference): + prediction = parse_digits(prediction) + reference = parse_digits(reference) + # number questions + if include_percentage: + gt_result = [reference / 100, reference, reference * 100] + else: + gt_result = [reference] + for item in gt_result: + try: + if is_close: + if numeric_equal(prediction, item): + return True + else: + if item == prediction: + return True + except Exception: + continue + return False + except Exception: + pass + + if not prediction and prediction not in [0, False]: + return False + + # 2. 
symbolic equal + reference = str(reference).strip() + prediction = str(prediction).strip() + + # pmatrix (amps) + if "pmatrix" in prediction and "pmatrix" not in reference: + reference = str_to_pmatrix(reference) + + # deal with [], (), {} + pred_str, ref_str = prediction, reference + if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or ( + prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[") + ): + pred_str = pred_str.strip("[]()") + ref_str = ref_str.strip("[]()") + for s in ["{", "}", "(", ")"]: + ref_str = ref_str.replace(s, "") + pred_str = pred_str.replace(s, "") + if pred_str.lower() == ref_str.lower(): + return True + + # [a, b] vs. [c, d], return a==c and b==d + if ( + regex.match(r"(\(|\[).+(\)|\])", prediction) is not None + and regex.match(r"(\(|\[).+(\)|\])", reference) is not None + ): + pred_parts = prediction[1:-1].split(",") + ref_parts = reference[1:-1].split(",") + if len(pred_parts) == len(ref_parts) and all( + [math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))] + ): + return True + if ( + (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) + and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) + and (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) + and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")) + ): + pred_lines = [ + line.strip() + for line in prediction[len("\\begin{pmatrix}") : -len("\\end{pmatrix}")].split("\\\\") + if line.strip() + ] + ref_lines = [ + line.strip() + for line in reference[len("\\begin{pmatrix}") : -len("\\end{pmatrix}")].split("\\\\") + if line.strip() + ] + matched = True + if len(pred_lines) == len(ref_lines): + for pred_line, ref_line in zip(pred_lines, ref_lines): + pred_parts = pred_line.split("&") + ref_parts = ref_line.split("&") + if len(pred_parts) == len(ref_parts): + if not all( + [ + math_equal( + pred_parts[i], + ref_parts[i], + include_percentage, + is_close, + ) + for i in range(len(pred_parts)) + ] + ): + matched = False + break + else: + matched = False + if not matched: + break + else: + matched = False + if matched: + return True + + if prediction.count("=") == 1 and reference.count("=") == 1: + pred = prediction.split("=") + pred = f"{pred[0].strip()} - ({pred[1].strip()})" + ref = reference.split("=") + ref = f"{ref[0].strip()} - ({ref[1].strip()})" + if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref): + return True + elif prediction.count("=") == 1 and len(prediction.split("=")[0].strip()) <= 2 and "=" not in reference: + if math_equal(prediction.split("=")[1], reference, include_percentage, is_close): + return True + elif reference.count("=") == 1 and len(reference.split("=")[0].strip()) <= 2 and "=" not in prediction: + if math_equal(prediction, reference.split("=")[1], include_percentage, is_close): + return True + + # symbolic equal with sympy + if timeout: + if call_with_timeout(symbolic_equal_process, prediction, reference): + return True + else: + if symbolic_equal(prediction, reference): + return True + + return False + + +def numeric_equal(prediction: float, reference: float): + # Note that relative tolerance has significant impact + # on the result of the synthesized GSM-Hard dataset + # if reference.is_integer(): + # return isclose(reference, round(prediction), abs_tol=1e-4) + # else: + # prediction = round(prediction, 
len(str(reference).split(".")[-1])) + return isclose(reference, prediction, rel_tol=1e-4) + + +def symbolic_equal(a, b): + def _parse(s): + for f in [parse_latex, parse_expr, latex2sympy]: + try: + return f(s.replace("\\\\", "\\")) + except Exception: + try: + return f(s) + except Exception: + pass + return s + + a = _parse(a) + b = _parse(b) + + # direct equal + try: + if str(a) == str(b) or a == b: + return True + except Exception: + pass + + # simplify equal + try: + if a.equals(b) or simplify(a - b) == 0: + return True + except Exception: + pass + + # equation equal + try: + if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)): + return True + except Exception: + pass + + try: + if numeric_equal(float(N(a)), float(N(b))): + return True + except Exception: + pass + + # matrix + try: + # if a and b are matrix + if a.shape == b.shape: + _a = a.applyfunc(lambda x: round(x, 3)) + _b = b.applyfunc(lambda x: round(x, 3)) + if _a.equals(_b): + return True + except Exception: + pass + + return False + + +def symbolic_equal_process(a, b, output_queue): + result = symbolic_equal(a, b) + output_queue.put(result) + + +def call_with_timeout(func, *args, timeout=1, **kwargs): + output_queue = multiprocessing.Queue() + process_args = args + (output_queue,) + process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) + process.start() + process.join(timeout) + + if process.is_alive(): + process.terminate() + process.join() + return False + + return output_queue.get() + + +def math_equal_with_timeout(pred, gt_ans, timeout): + def target(result_queue): + try: + result_queue.put(math_equal(pred, gt_ans)) + except Exception as e: + result_queue.put(e) + + result_queue = multiprocessing.Queue() + process = multiprocessing.Process(target=target, args=(result_queue,)) + process.start() + + process.join(timeout) + + if process.is_alive(): + print(f"Timeout occurred for prediction: {pred}") + process.terminate() + process.join() + return False + + try: + result = result_queue.get_nowait() + except queue.Empty: + print("Result queue timed out") + return False + + if isinstance(result, Exception): + print(f"Error occurred: {result}") + return False + + return result + + +def parallel_math_equal(all_pred, gt_ans, timeout=20): + results = [] + for pred in all_pred: + results.append(math_equal_with_timeout(pred, gt_ans, timeout)) + return results diff --git a/modules/genai_optimizations/benchmarks/utils.py b/modules/genai_optimizations/benchmarks/utils.py index 7c1ff05cc..275b0b1e7 100644 --- a/modules/genai_optimizations/benchmarks/utils.py +++ b/modules/genai_optimizations/benchmarks/utils.py @@ -1,7 +1,7 @@ # Copyright (C) 2018-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from genai_opt import SparseAttention -from genai_opt import KVCacheCompressionMode, KVCacheCompressionParameters, KVCacheCompressor +from genai_opt import KVCacheCompressionMode, KVCacheCompressionParameters, KVCacheCompressor, KVCacheRefinedSelection def add_visual_pruning_args(parser): group = parser.add_argument_group("Visual Token Pruning Arguments") @@ -34,7 +34,7 @@ def add_attention_args(parser): def add_token_eviction_args(parser): group = parser.add_argument_group("Token Eviction Arguments") group.add_argument("--enable_eviction", action="store_true", help="Enable token eviction") - group.add_argument("--algorithm", default="snapkv", choices=["snapkv", "h2o"], help="The KV cache eviction algorithm") + group.add_argument("--algorithm", default="snapkv", choices=["snapkv", "h2o", "rkv"], help="The KV 
cache eviction algorithm") group.add_argument("--granularity", default="per_group", choices=["per_token", "per_group"], help="Eviction granularity") group.add_argument( "--normalize_scores", @@ -51,6 +51,26 @@ def add_token_eviction_args(parser): group.add_argument("--recent_tokens", type=int, default=128, help="The number of most recent tokens to be retained") group.add_argument("--group_size", type=int, default=32, help="Group size for per-group eviction strategy") group.add_argument("--window_size", type=int, default=None, help="The size of the importance score aggregation window") + group.add_argument( + "--refined_algorithm", + type=str, + default=None, + choices=["kvcrush", "diversekv"], + help="The refined scoring strategy for selecting tokens within the intermediate region" + ) + group.add_argument( + "--refined_tokens", + type=int, + default=0, + help="The number of tokens within the intermediate region that will be selected using a refined scoring strategy" + ) + group.add_argument( + "--kvcrush_anchor", + type=str, + default="alternate", + choices=["random", "zeros", "ones", "mean", "alternate"], + help="The anchor point for the KVCrush algorithm" + ) return parser @@ -77,5 +97,8 @@ def get_eviction_patcher(args): intermediate_tokens=args.intermediate_tokens, normalize_scores=args.normalize_scores, window_size=args.window_size, + refined_algorithm=KVCacheRefinedSelection(args.refined_algorithm) if args.refined_algorithm else None, + refined_tokens=args.refined_tokens, + kvcrush_anchor=args.kvcrush_anchor, ) return KVCacheCompressor(eviction_parameters=params) diff --git a/modules/genai_optimizations/genai_opt/__init__.py b/modules/genai_optimizations/genai_opt/__init__.py index 66ccabb28..c6c3adf33 100644 --- a/modules/genai_optimizations/genai_opt/__init__.py +++ b/modules/genai_optimizations/genai_opt/__init__.py @@ -3,4 +3,9 @@ from genai_opt.visual_token_pruning import get_inputs_embeds from genai_opt.sparse_attention import SparseAttention -from genai_opt.token_eviction import KVCacheCompressionMode, KVCacheCompressionParameters, KVCacheCompressor +from genai_opt.token_eviction import ( + KVCacheCompressionMode, + KVCacheCompressionParameters, + KVCacheCompressor, + KVCacheRefinedSelection, +) diff --git a/modules/genai_optimizations/genai_opt/sparse_attention.py b/modules/genai_optimizations/genai_opt/sparse_attention.py index 07b31e5cd..0978a16be 100644 --- a/modules/genai_optimizations/genai_opt/sparse_attention.py +++ b/modules/genai_optimizations/genai_opt/sparse_attention.py @@ -16,6 +16,7 @@ from transformers.cache_utils import Cache from transformers.models.llama.modeling_llama import repeat_kv from transformers.models.llama.modeling_llama import apply_rotary_pos_emb +from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb as phi3_apply_rotary_pos_emb from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb from block_sparse_attn import block_sparse_attn_func @@ -619,7 +620,7 @@ def qwen2_vl_forward( value_states=value_states, attention_mask=attention_mask, scaling=module.scaling, - dropout_p=module.attention_dropout if module.training else 0.0, + dropout=module.attention_dropout if module.training else 0.0, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() @@ -657,7 +658,91 @@ def llama_forward( key_states=key_states, value_states=value_states, attention_mask=attention_mask, - dropout_p=module.attention_dropout if module.training else 0.0, + dropout=module.attention_dropout if module.training else 
0.0, + scaling=module.scaling, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = module.o_proj(attn_output) + return attn_output, attn_weights + + +def qwen3_forward( + module, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, module.head_dim) + + query_states = module.q_norm(module.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = module.k_norm(module.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = module.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, module.layer_idx, cache_kwargs) + + attn_output, attn_weights = module.attn_interface( + module, + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + dropout=module.attention_dropout if module.training else 0.0, + scaling=module.scaling, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = module.o_proj(attn_output) + return attn_output, attn_weights + + +def phi_forward( + module, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, module.head_dim) + + qkv = module.qkv_proj(hidden_states) + query_pos = module.config.num_attention_heads * module.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + module.num_key_value_heads * module.head_dim] + value_states = qkv[..., query_pos + module.num_key_value_heads * module.head_dim :] + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = phi3_apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, module.layer_idx, cache_kwargs) + + attn_output, attn_weights = module.attn_interface( + module, + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + dropout=module.attention_dropout if module.training else 0.0, scaling=module.scaling, ) @@ -672,6 +757,8 @@ def llama_forward( "LlamaForCausalLM": llama_forward, "MistralForCausalLM": llama_forward, "Qwen2ForCausalLM": llama_forward, + 
"Qwen3ForCausalLM": qwen3_forward, + "Phi3ForCausalLM": phi_forward, } def get_custom_attn_forward(model: PreTrainedModel): diff --git a/modules/genai_optimizations/genai_opt/token_eviction.py b/modules/genai_optimizations/genai_opt/token_eviction.py index 02eb692d4..4625a13ef 100644 --- a/modules/genai_optimizations/genai_opt/token_eviction.py +++ b/modules/genai_optimizations/genai_opt/token_eviction.py @@ -19,6 +19,12 @@ class KVCacheCompressionMode(Enum): H2O = "h2o" SNAPKV = "snapkv" + RKV = "rkv" + + +class KVCacheRefinedSelection(Enum): + KVCRUSH = "kvcrush" + DIVERSEKV = "diversekv" @dataclass @@ -40,6 +46,15 @@ class KVCacheCompressionParameters: :param intermediate_tokens: The number of tokens between the "start" and "recent" areas of KV cache that will be considered for eviction. :type intermediate_tokens: int + :param refined_algorithm: The refined scoring strategy for selecting tokens within the intermediate region. + :type refined_algorithm: KVCacheRefinedSelection + :param refined_tokens: The number of tokens within the intermediate region that will be selected + using a secondary - refined scoring strategy (e.g., KVCrush, DiverseKV algo). + If set to 0 (default), the entire intermediate region is processed using the primary selection method. + :type refined_tokens: int + :param kvcrush_anchor: The anchor point for the KVCrush algorithm, + which can be "alternate", "random", "zeros", "ones", or "mean". Defaults to "alternate". + :type kvcrush_anchor: str :param normalize_scores: Whether to normalize the attention scores by the number of times each token was attended to. :type normalize_scores: bool :param window_size: The size of the importance score aggregation window @@ -52,6 +67,9 @@ class KVCacheCompressionParameters: start_tokens: int = 32 recent_tokens: int = 128 intermediate_tokens: int = 512 + refined_algorithm: Optional[KVCacheRefinedSelection] = None + refined_tokens: int = 64 + kvcrush_anchor: str = "alternate" normalize_scores: bool = False window_size: Optional[int] = None @@ -67,11 +85,18 @@ def __init__(self, eviction_parameters: KVCacheCompressionParameters = KVCacheCo self.algorithm = eviction_parameters.algorithm self.window_size = eviction_parameters.window_size - if self.algorithm == KVCacheCompressionMode.SNAPKV and self.window_size is None: - self.window_size = 8 # Default window size for SnapKV + if self.algorithm != KVCacheCompressionMode.H2O and self.window_size is None: + logger.info(f"Set window size for {self.algorithm} to 8") + self.window_size = 8 # Default window size for SnapKV and RKV - self._scores = [] - self._cache_counter = [] if self.normalize_scores else None + self.refined_algorithm = eviction_parameters.refined_algorithm + self.refined_tokens = eviction_parameters.refined_tokens + self.adaptive_refined_size = self.refined_algorithm is not None and self.refined_tokens == 0 + self.kvcrush_anchor = eviction_parameters.kvcrush_anchor + self.attn_mass_threshold = 0.9 + + self._scores = {} + self._cache_counter = {} if self.normalize_scores else None self._validate_arguments() @@ -81,13 +106,20 @@ def _validate_arguments(self): Raises a ValueError at the end if any condition fails. """ error_msg = None - if self.start_tokens < 0 or self.recent_tokens < 0 or self.intermediate_tokens < 0: + if any(x < 0 for x in ( + self.start_tokens, + self.recent_tokens, + self.intermediate_tokens, + self.refined_tokens, + )): error_msg = "KV cache sizes must be non-negative integers." 
+ elif self.refined_tokens > self.intermediate_tokens: + error_msg = "refined_tokens cannot be greater than intermediate_tokens." elif self.start_tokens + self.recent_tokens + self.intermediate_tokens <= 0: error_msg = "At least one of the KV cache sizes must be greater than zero." elif any( size % self.group_size != 0 - for size in (self.start_tokens, self.recent_tokens, self.intermediate_tokens) + for size in (self.start_tokens, self.recent_tokens, self.intermediate_tokens, self.refined_tokens) ): error_msg = "KV cache part sizes must be divisible by the group size." elif self.window_size is not None and self.algorithm == KVCacheCompressionMode.H2O: @@ -96,6 +128,11 @@ def _validate_arguments(self): error_msg = "Window size must be a positive integer if specified." elif self.granularity not in {"per_token", "per_group"}: error_msg = f"Granularity {self.granularity} is not supported. Supported granularities: 'per_token', 'per_group'." + elif self.kvcrush_anchor not in {"random", "zeros", "ones", "mean", "alternate"}: + error_msg = ( + f"Unknown KVCrush anchor: {self.kvcrush_anchor}. " + "Supported anchors: 'random', 'zeros', 'ones', 'mean', 'alternate'." + ) if error_msg: raise ValueError(error_msg) @@ -111,14 +148,17 @@ def clean(self): """ Resets the scores and cache counter. """ - self._scores = [] - self._cache_counter = [] if self.normalize_scores else None + self._scores = {} + self._cache_counter = {} if self.normalize_scores else None def aggregate_scores(self, layer_idx, attn_w): """ Updates the scores based on the attention weights. """ - layer_scores = self._scores[layer_idx] if len(self._scores) > layer_idx else None + if self.algorithm == KVCacheCompressionMode.RKV: + return self._update_rkv_scores(layer_idx, attn_w) + + layer_scores = self._scores.get(layer_idx, None) if self.window_size is not None: hh_score = attn_w[..., -self.window_size :, :].sum(dim=(0, 2)) # sum over batch and query length @@ -138,17 +178,20 @@ def aggregate_scores(self, layer_idx, attn_w): hh_score[:, :-num_new_tokens] += layer_scores layer_scores = hh_score - layer_counter = self._calculate_layer_counter(layer_idx, num_new_tokens) - self._update_layer_scores(layer_idx, layer_scores, layer_counter) + layer_counter = self._calculate_layer_counter(layer_idx, num_new_tokens, device=attn_w.device) + + self._scores[layer_idx] = layer_scores + if self.normalize_scores: + self._cache_counter[layer_idx] = layer_counter - def _calculate_layer_counter(self, layer_idx, num_new_tokens): + def _calculate_layer_counter(self, layer_idx, num_new_tokens, device): if not self.normalize_scores: return None new_count_size = num_new_tokens if self.window_size is not None: new_count_size = min(self.window_size, num_new_tokens) - new_counters = torch.arange(new_count_size, 0, -1) + new_counters = torch.arange(new_count_size, 0, -1, device=device) if len(self._cache_counter) > layer_idx: layer_counter = self._cache_counter[layer_idx] @@ -156,28 +199,185 @@ def _calculate_layer_counter(self, layer_idx, num_new_tokens): layer_counter = torch.cat((layer_counter, new_counters), dim=-1) else: if self.window_size is not None and num_new_tokens > self.window_size: - full_window = torch.full((num_new_tokens - self.window_size,), self.window_size) + full_window = torch.full((num_new_tokens - self.window_size,), self.window_size, device=device) layer_counter = torch.cat((full_window, new_counters), dim=0) else: layer_counter = new_counters return layer_counter - def _update_layer_scores(self, layer_idx, layer_scores, 
layer_counter=None): - if len(self._scores) <= layer_idx: - self._scores.append(layer_scores) - if self.normalize_scores: - self._cache_counter.append(layer_counter.to(layer_scores.device)) + def _update_rkv_scores(self, layer_idx: int, attn_w: torch.Tensor) -> None: + """ + Updates the scores for the decoding phase like in R-KV and RPC papers. + """ + hh_score = attn_w.sum(0) # Sum over batch, shape: (H, q_len, k_len) + + layer_scores = self._scores.get(layer_idx, None) + if layer_scores is None: + self._scores[layer_idx] = hh_score else: - self._scores[layer_idx] = layer_scores - if self.normalize_scores: - self._cache_counter[layer_idx] = layer_counter.to(layer_scores.device) + layer_scores = self._scores[layer_idx] + new_tokens = hh_score.shape[-1] - layer_scores.shape[-1] + self._scores[layer_idx] = torch.cat( + ( + F.pad(layer_scores, (0, new_tokens), mode="constant", value=0), + hh_score, + ), + dim=-2, + ) + + # Keep only the last `window_size` scores + if self._scores[layer_idx].shape[1] > self.window_size: + self._scores[layer_idx] = self._scores[layer_idx][:, -self.window_size :, :] def get_scores(self, layer_idx): - return ( - self._scores[layer_idx] - if not self.normalize_scores - else self._scores[layer_idx] / self._cache_counter[layer_idx] + if self._scores[layer_idx].dim() == 2: + return ( + self._scores[layer_idx] + if not self.normalize_scores + else self._scores[layer_idx] / self._cache_counter[layer_idx] + ) + + # Average over query length, shape: (H, k_len) + scores = self._scores[layer_idx].mean(dim=-2) + scores = F.max_pool1d( + scores, + kernel_size=7, + padding=7 // 2, + stride=1, ) + del self._scores[layer_idx] # Clear scores after retrieval + return scores.mean(0, keepdim=True)[:, self.start_tokens :] # Average over heads, shape: (1, k_len) + + def _get_keys_similarity(self, key_states): + keys_normalized = key_states / key_states.norm(dim=-1, keepdim=True) + similarity = torch.matmul(keys_normalized, keys_normalized.transpose(-1, -2)) + similarity = similarity[:, :, self.start_tokens :, self.start_tokens :] + # Aggregate over batch + similarity = similarity.mean(dim=0) + + for h in range(similarity.shape[0]): + similarity[h].fill_diagonal_(0.0) + + # Zero out values below mean similarity for each head + head_means = similarity.view(similarity.shape[0], -1).mean(dim=-1, keepdim=True) + thr = head_means.unsqueeze(-1) + similarity = torch.where(similarity >= thr, similarity, torch.zeros_like(similarity)) + + # Aggregate over heads + similarity = similarity.mean(dim=0) + return similarity + + def get_refined_indices(self, scores: torch.Tensor, kwargs: dict) -> torch.Tensor: + device = scores.device + refined_size = self.refined_tokens // self.group_size + if self.refined_algorithm == KVCacheRefinedSelection.KVCRUSH: + B, _ = scores.shape + if B != 1: + error_msg = "KVCacheCompressor with KVCrush algorithm supports only batch size of 1." 
@@ -215,41 +415,71 @@ def _convert_group_indices(self, group_indices, seq_len):

         return indices

+    def _set_balanced_refined_size(self, intermediate_scores):
+        target_mass = self.attn_mass_threshold * intermediate_scores.sum(dim=-1)
+        vals, _ = torch.sort(intermediate_scores, descending=True, dim=-1)
+        cumsum = vals.cumsum(dim=-1)
+        cutoff = (cumsum >= target_mass).nonzero(as_tuple=False)
+        # Minimum number of groups needed to cover the target mass
+        k_min = cutoff[0].item() + 1  # +1 because indices are 0-based
+        self.refined_tokens = max(0, self.intermediate_tokens - k_min * self.group_size)
+
     def get_remaining_indices(self, scores: torch.Tensor, kwargs: dict) -> torch.Tensor:
         """
         Computes the indices of the keep tokens in the KV cache after compression.
         """
         seq_len = self.start_tokens + scores.shape[-1]

-        if self.granularity == "per_token":
-            start_size = self.start_tokens
-            intermediate_size = self.intermediate_tokens
-            recent_size = self.recent_tokens
-        elif self.granularity == "per_group":
-            start_size = self.start_tokens // self.group_size
-            intermediate_size = self.intermediate_tokens // self.group_size
-            recent_size = self.recent_tokens // self.group_size
+        start_size = self.start_tokens // self.group_size
+        intermediate_size = self.intermediate_tokens // self.group_size
+        recent_size = self.recent_tokens // self.group_size
+
+        if self.granularity == "per_group":
             pad = scores.shape[-1] % self.group_size
             if pad:
                 scores = F.pad(scores, (0, self.group_size - pad), mode="constant", value=0)
-            scores = scores.view(-1, self.group_size).sum(-1)  # Sum token scores inside group
+            padded_scores = scores.view(-1, self.group_size).sum(-1)  # Sum token scores inside group
+        else:
+            padded_scores = scores.squeeze(0)
+
+        intermediate_scores = padded_scores[:-recent_size] if recent_size > 0 else padded_scores
+
+        if self.adaptive_refined_size:
+            self._set_balanced_refined_size(intermediate_scores)
+
+        refined_size = self.refined_tokens // self.group_size
+        coarse_size = intermediate_size - refined_size

         keep_groups = []
-        size = scores.shape[0]

         if start_size > 0:
             keep_past = torch.arange(0, start_size, device=scores.device)
             keep_groups.append(keep_past)

         if intermediate_size > 0:
-            intermediate_scores = scores[:size - recent_size]
+            if coarse_size > 0:
+                _, keep_coarse = torch.topk(intermediate_scores, coarse_size, dim=-1)
+                keep_coarse = keep_coarse.sort().values + start_size
+                keep_groups.append(keep_coarse)
+
+            if refined_size > 0:
+                refined_scores = scores[:, :len(intermediate_scores) * self.group_size]
+
+                if coarse_size > 0:
+                    coarse_idx = keep_coarse.unsqueeze(0) - start_size
+                    mask = torch.zeros_like(refined_scores, dtype=torch.bool)

-            _, keep_coarse = torch.topk(intermediate_scores, intermediate_size, dim=-1)
-            keep_coarse = keep_coarse.sort().values + start_size
-            keep_groups.append(keep_coarse)
+                    if self.granularity == "per_group":
+                        coarse_idx = self._convert_group_indices(
+                            coarse_idx, coarse_idx.shape[-1] * self.group_size
+                        )
+
+                    mask.scatter_(1, coarse_idx, True)  # Exclude coarse-kept positions from refined selection
+                    refined_scores = refined_scores.masked_fill(mask, float("-inf"))
+
+                refined_topk = self.get_refined_indices(refined_scores, kwargs) + start_size
+                keep_groups.append(refined_topk)

         if recent_size > 0:
+            padded_len = padded_scores.shape[0]
             keep_recent = (
-                torch.arange(size - recent_size, size, device=scores.device) + start_size
+                torch.arange(padded_len - recent_size, padded_len, device=scores.device) + start_size
             )
             keep_groups.append(keep_recent)
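A worked miniature of the `_set_balanced_refined_size` budget split above, assuming `attn_mass_threshold = 0.9` and `group_size = 1` (both values illustrative):

```python
import torch

intermediate_tokens = 6
attn_mass_threshold = 0.9
scores = torch.tensor([5.0, 3.0, 0.5, 0.3, 0.1, 0.1])  # per-group importance

vals, _ = torch.sort(scores, descending=True)
cumsum = vals.cumsum(dim=-1)                        # [5.0, 8.0, 8.5, 8.8, 8.9, 9.0]
target = attn_mass_threshold * scores.sum()         # 0.9 * 9.0 = 8.1
k_min = int((cumsum >= target).nonzero()[0]) + 1    # 3 groups cover >= 8.1

# The rest of the intermediate budget is handed to the refined selector
refined_tokens = max(0, intermediate_tokens - k_min)  # 6 - 3 = 3
```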
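The `forward_hook` change in the final hunk of this file forwards the raw `keys` only when DiverseKV is active; a brief standalone sketch of what `_get_keys_similarity` then computes from them (head and sequence sizes are illustrative):

```python
import torch

keys = torch.rand(1, 2, 8, 16)                    # (batch, heads, seq, head_dim)

normed = keys / keys.norm(dim=-1, keepdim=True)   # unit-norm key vectors
sim = normed @ normed.transpose(-1, -2)           # pairwise cosine similarity
sim = sim.mean(dim=0)                             # aggregate over batch -> (heads, seq, seq)

for h in range(sim.shape[0]):
    sim[h].fill_diagonal_(0.0)                    # ignore self-similarity

# Keep only above-mean similarities per head, then average over heads
thr = sim.view(sim.shape[0], -1).mean(dim=-1, keepdim=True).unsqueeze(-1)
sim = torch.where(sim >= thr, sim, torch.zeros_like(sim)).mean(dim=0)

# A token highly similar to already kept tokens is redundant; low mean
# similarity to the kept set means high diversity, as in get_refined_indices
kept = torch.tensor([0, 3])
diversity = -sim[:, kept].mean(dim=-1)
```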
""" seq_len = self.start_tokens + scores.shape[-1] - if self.granularity == "per_token": - start_size = self.start_tokens - intermediate_size = self.intermediate_tokens - recent_size = self.recent_tokens - elif self.granularity == "per_group": - start_size = self.start_tokens // self.group_size - intermediate_size = self.intermediate_tokens // self.group_size - recent_size = self.recent_tokens // self.group_size + start_size = self.start_tokens // self.group_size + intermediate_size = self.intermediate_tokens // self.group_size + recent_size = self.recent_tokens // self.group_size + if self.granularity == "per_group": pad = scores.shape[-1] % self.group_size if pad: scores = F.pad(scores, (0, self.group_size - pad), mode="constant", value=0) - scores = scores.view(-1, self.group_size).sum(-1) # Sum token scores inside group + padded_scores = scores.view(-1, self.group_size).sum(-1) # Sum token scores inside group + else: + padded_scores = scores.squeeze(0) + intermediate_scores = padded_scores[:-recent_size] if recent_size > 0 else padded_scores + + if self.adaptive_refined_size: + self._set_balanced_refined_size(intermediate_scores) + refined_size = self.refined_tokens // self.group_size + coarse_size = intermediate_size - refined_size keep_groups = [] - size = scores.shape[0] if start_size > 0: keep_past = torch.arange(0, start_size, device=scores.device) keep_groups.append(keep_past) if intermediate_size > 0: - intermediate_scores = scores[:size - recent_size] + if coarse_size > 0: + _, keep_coarse = torch.topk(intermediate_scores, coarse_size, dim=-1) + keep_coarse = keep_coarse.sort().values + start_size + keep_groups.append(keep_coarse) + + if refined_size > 0: + refined_scores = scores[:, :len(intermediate_scores) * self.group_size] + + if coarse_size > 0: + coarse_idx = keep_coarse.unsqueeze(0) - start_size + mask = torch.zeros_like(refined_scores, dtype=torch.bool) - _, keep_coarse = torch.topk(intermediate_scores, intermediate_size, dim=-1) - keep_coarse = keep_coarse.sort().values + start_size - keep_groups.append(keep_coarse) + if self.granularity == "per_group": + coarse_idx = self._convert_group_indices( + coarse_idx, coarse_idx.shape[-1] * self.group_size + ) + + mask.scatter_(1, coarse_idx, True) # Ensure no OOB here + refined_scores = refined_scores.masked_fill(mask, float("-inf")) + + refined_topk = self.get_refined_indices(refined_scores, kwargs) + start_size + keep_groups.append(refined_topk) if recent_size > 0: + padded_len = padded_scores.shape[0] keep_recent = ( - torch.arange(size - recent_size, size, device=scores.device) + start_size + torch.arange(padded_len - recent_size, padded_len, device=scores.device) + start_size ) keep_groups.append(keep_recent) @@ -313,6 +543,9 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic seq_len = keys.shape[-2] if seq_len > self.max_cache_size: + if self.refined_algorithm == KVCacheRefinedSelection.DIVERSEKV: + kwargs["keys"] = keys + keys, values = self.compress(layer_idx, keys, values, kwargs) cache.layers[layer_idx].keys = keys diff --git a/modules/genai_optimizations/setup.py b/modules/genai_optimizations/setup.py index dfd8baeb2..d53945ba4 100644 --- a/modules/genai_optimizations/setup.py +++ b/modules/genai_optimizations/setup.py @@ -13,6 +13,8 @@ "bitsandbytes==0.47.0", "protobuf", "sentencepiece==0.2.1", + "latex2sympy2", + "word2number", ], }