
Commit 5fe2ff3

Add longbench benchmark
1 parent 4100647 commit 5fe2ff3

10 files changed, +601 -2 lines changed

.flake8

+2
@@ -3,5 +3,7 @@ max-line-length = 120
 per-file-ignores =
     __init__.py:F401
     evaluation/infinite_bench/create_huggingface_dataset.py:E501
+    evaluation/longbench/create_huggingface_dataset.py:E501
+    evaluation/longbenchv2/create_huggingface_dataset.py:E501
 # E203, W503 - black-compatible config
 extend-ignore = E203, W503

evaluation/README.md

+2
@@ -5,6 +5,8 @@ This directory contains a set of scripts to evaluate the performance of differen
 - [RULER](ruler/README.md) ([hf link](https://huggingface.co/datasets/simonjegou/ruler))
 - [Zero Scrolls](zero_scrolls/README.md) ([hf link](https://huggingface.co/datasets/simonjegou/zero_scrolls))
 - [Infinitebench](infinite_bench/README.md) ([hf link](https://huggingface.co/datasets/MaxJeblick/InfiniteBench))
+- [longbench](longbench/README.md) ([hf link](https://huggingface.co/datasets/Xnhyacinth/LongBench))
+- [longbench-v2](longbenchv2/README.md) ([hf link](https://huggingface.co/datasets/Xnhyacinth/LongBench-v2))


 Please refer to the README of each dataset for more information on how the Hugging Face dataset was generated.

evaluation/evaluate.py

+29 -1
@@ -10,6 +10,9 @@
 from datasets import load_dataset
 from fire import Fire
 from infinite_bench.calculate_metrics import calculate_metrics as infinite_bench_scorer
+from longbench.calculate_metrics import calculate_metrics as longbench_scorer
+from longbench.calculate_metrics import calculate_metrics_e as longbench_scorer_e
+from longbenchv2.calculate_metrics import calculate_metrics as longbenchv2_scorer
 from loogle.calculate_metrics import calculate_metrics as loogle_scorer
 from ruler.calculate_metrics import calculate_metrics as ruler_scorer
 from tqdm import tqdm
@@ -19,6 +22,7 @@
 from kvpress import (
     AdaKVPress,
     ChunkKVPress,
+    ComposedPress,
     CriticalAdaKVPress,
     CriticalKVPress,
     DuoAttentionPress,
@@ -39,13 +43,19 @@
     "ruler": "simonjegou/ruler",
     "zero_scrolls": "simonjegou/zero_scrolls",
     "infinitebench": "MaxJeblick/InfiniteBench",
+    "longbench": "Xnhyacinth/LongBench",
+    "longbench-e": "Xnhyacinth/LongBench",
+    "longbench-v2": "Xnhyacinth/LongBench-v2",
 }

 SCORER_DICT = {
     "loogle": loogle_scorer,
     "ruler": ruler_scorer,
     "zero_scrolls": zero_scrolls_scorer,
     "infinitebench": infinite_bench_scorer,
+    "longbench": longbench_scorer,
+    "longbench-e": longbench_scorer_e,
+    "longbench-v2": longbenchv2_scorer,
 }

 PRESS_DICT = {
@@ -66,6 +76,8 @@
     "tova": TOVAPress(),
     "duo_attention": DuoAttentionPress(),
     "chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
+    "snap_think": ComposedPress([SnapKVPress(), ThinKPress()]),
+    "full_kv": ExpectedAttentionPress(0.0),
 }


@@ -80,6 +92,7 @@ def evaluate(
     max_new_tokens: Optional[int] = None,
     max_context_length: Optional[int] = None,
     compress_questions: bool = False,
+    key_channel_compression_ratio: float = 0.5,
 ):
     """
     Evaluate a model on a dataset using a press and save the results
@@ -106,6 +119,8 @@ def evaluate(
         Maximum number of tokens to use in the context. By default will use the maximum length supported by the model.
     compress_questions : bool, optional
         Whether to compress the questions as well, by default False
+    key_channel_compression_ratio : float, optional
+        Key channel compression ratio used by ThinKPress, by default 0.5
     """

     assert dataset in DATASET_DICT, f"No dataset found for {dataset}"
@@ -146,6 +161,20 @@ def evaluate(

     if isinstance(press, (DuoAttentionPress)):
         press.head_compression_ratio = compression_ratio
+    elif isinstance(press, (ComposedPress)):
+        for ps in press.presses:
+            if isinstance(ps, (ThinKPress)):
+                ps.key_channel_compression_ratio = key_channel_compression_ratio
+                save_filename = save_filename.with_name(
+                    save_filename.stem + f"__channel{key_channel_compression_ratio}" + save_filename.suffix
+                )
+            else:
+                ps.compression_ratio = compression_ratio  # type:ignore[attr-defined]
+    elif isinstance(press, (ThinKPress)):
+        press.key_channel_compression_ratio = key_channel_compression_ratio
+        save_filename = save_filename.with_name(
+            save_filename.stem + f"__channel{key_channel_compression_ratio}" + save_filename.suffix
+        )
     else:
         press.compression_ratio = compression_ratio  # type:ignore[attr-defined]

@@ -165,7 +194,6 @@ def evaluate(
         pipe = pipeline("kv-press-text-generation", model=model, device_map="auto", model_kwargs=model_kwargs)
     else:
         pipe = pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)
-
     # Run pipeline on each context
     df["predicted_answer"] = None
     df_context = df.groupby("context")
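
The press-configuration branch added above is the core of this change: a ComposedPress holds several presses, and ThinKPress is driven by a key-channel ratio instead of the usual token compression ratio. Below is a minimal standalone sketch of that logic using the kvpress classes referenced in the diff; the ratio values are illustrative, not defaults from this commit.

```python
# Sketch mirroring the new "snap_think" configuration logic above (illustrative ratios).
from kvpress import ComposedPress, SnapKVPress, ThinKPress

press = ComposedPress([SnapKVPress(), ThinKPress()])
for ps in press.presses:
    if isinstance(ps, ThinKPress):
        # ThinKPress prunes key channels, so it is controlled by a channel ratio
        ps.key_channel_compression_ratio = 0.5
    else:
        # token-dropping presses (here SnapKVPress) take the standard compression ratio
        ps.compression_ratio = 0.25
```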

evaluation/longbench/README.md

+7
@@ -0,0 +1,7 @@
# longbench dataset

This benchmark comes from [longbench](https://github.com/THUDM/LongBench/tree/main/LongBench).

## Create Hugging Face dataset

The processed Hugging Face dataset for longbench can be found [here](https://huggingface.co/datasets/Xnhyacinth/LongBench). To reproduce this dataset, simply run the `create_huggingface_dataset.py` script.
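
As a quick sanity check, the processed dataset can also be loaded directly from the Hub. This is a hypothetical sketch: it assumes each LongBench task is exposed through the `data_dir` argument with a `test` split, which should be confirmed on the dataset card.

```python
# Hypothetical loading sketch; the data_dir/split layout is an assumption, not documented here.
from datasets import load_dataset

df = load_dataset("Xnhyacinth/LongBench", data_dir="narrativeqa", split="test").to_pandas()
print(len(df), list(df.columns))
```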
evaluation/longbench/calculate_metrics.py

+236

@@ -0,0 +1,236 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import re
import string
from collections import Counter
import numpy as np
from rouge import Rouge

try:
    import jieba
    from fuzzywuzzy import fuzz
except ImportError as e:
    missing_module = str(e).split()[-1].strip("'")  # Extract missing module name
    print(
        f"Module '{missing_module}' not found. "
        f"To evaluate LongBench, please install it with 'pip install {missing_module}'"
    )


def calculate_metrics(df):
    predictions = df["predicted_answer"].tolist()
    answers = df["answers"].tolist()
    dataset = df["task"].tolist()[0]
    all_classes = df["all_classes"].tolist()[0]
    return scorer(dataset, predictions, answers, all_classes)


def calculate_metrics_e(df):
    predictions = df["predicted_answer"].tolist()
    answers = df["answers"].tolist()
    dataset = df["task"].tolist()[0].removesuffix("-e")
    all_classes = df["all_classes"].tolist()[0]
    lengths = df["length"].tolist()
    return scorer_e(dataset, predictions, answers, lengths, all_classes)


def scorer_e(dataset, predictions, answers, lengths, all_classes):
    scores = {"0-4k": [], "4-8k": [], "8k+": []}  # type:ignore[var-annotated]
    for (prediction, ground_truths, length) in zip(predictions, answers, lengths):
        score = 0.0
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            prediction = prediction.lstrip("\n").split("\n")[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
        if length < 4000:
            scores["0-4k"].append(score)
        elif length < 8000:
            scores["4-8k"].append(score)
        else:
            scores["8k+"].append(score)
    for key in scores.keys():
        scores[key] = round(100 * np.mean(scores[key]), 2)
    return scores


def scorer(dataset, predictions, answers, all_classes):
    total_score = 0.0
    for (prediction, ground_truths) in zip(predictions, answers):
        score = 0.0
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            prediction = prediction.lstrip("\n").split("\n")[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
        total_score += score
    return round(100 * total_score / len(predictions), 2)


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def normalize_zh_answer(s):
    """Lower text and remove punctuation, extra whitespace."""

    def white_space_fix(text):
        return "".join(text.split())

    def remove_punc(text):
        cn_punctuation = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
        all_punctuation = set(string.punctuation + cn_punctuation)
        return "".join(ch for ch in text if ch not in all_punctuation)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_punc(lower(s)))


def count_score(prediction, ground_truth, **kwargs):
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)


def retrieval_score(prediction, ground_truth, **kwargs):
    pattern = r"Paragraph (\d+)"
    matches = re.findall(pattern, ground_truth)
    ground_truth_id = matches[0]
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth_id):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)


def retrieval_zh_score(prediction, ground_truth, **kwargs):
    pattern = r"段落(\d+)"
    matches = re.findall(pattern, ground_truth)
    ground_truth_id = matches[0]
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth_id):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)


def code_sim_score(prediction, ground_truth, **kwargs):
    all_lines = prediction.lstrip("\n").split("\n")
    prediction = ""
    for line in all_lines:
        if ("`" not in line) and ("#" not in line) and ("//" not in line):
            prediction = line
            break
    return fuzz.ratio(prediction, ground_truth) / 100


def classification_score(prediction, ground_truth, **kwargs):
    em_match_list = []
    all_classes = kwargs["all_classes"]
    for class_name in all_classes:
        if class_name in prediction:
            em_match_list.append(class_name)
    for match_term in em_match_list:
        if match_term in ground_truth and match_term != ground_truth:
            em_match_list.remove(match_term)
    if ground_truth in em_match_list:
        score = 1.0 / len(em_match_list)
    else:
        score = 0.0
    return score


def rouge_score(prediction, ground_truth, **kwargs):
    rouge = Rouge()
    try:
        scores = rouge.get_scores([prediction], [ground_truth], avg=True)
    except Exception as e:
        print(f"An error occurred: {e}")
        return 0.0
    return scores["rouge-l"]["f"]


def rouge_zh_score(prediction, ground_truth, **kwargs):
    prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
    ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
    score = rouge_score(prediction, ground_truth)
    return score


def f1_score(prediction, ground_truth, **kwargs):
    common = Counter(prediction) & Counter(ground_truth)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction)
    recall = 1.0 * num_same / len(ground_truth)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def qa_f1_score(prediction, ground_truth, **kwargs):
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    return f1_score(prediction_tokens, ground_truth_tokens)


def qa_f1_zh_score(prediction, ground_truth, **kwargs):
    prediction_tokens = list(jieba.cut(prediction, cut_all=False))
    ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
    prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
    ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
    prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
    ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
    return f1_score(prediction_tokens, ground_truth_tokens)


dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "multifieldqa_zh": qa_f1_zh_score,
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "dureader": rouge_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
}
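
For reference, the scorers above expect a pandas DataFrame with `predicted_answer`, `answers`, `task`, and `all_classes` columns (plus `length` for the `-e` variant). A small hypothetical check, not part of the commit:

```python
# Hypothetical usage of the LongBench scorer on a toy DataFrame
# (import path assumes the script is run from the evaluation/ directory, as evaluate.py does).
import pandas as pd

from longbench.calculate_metrics import calculate_metrics

df = pd.DataFrame(
    {
        "task": ["narrativeqa", "narrativeqa"],
        "predicted_answer": ["Paris", "a small village in France"],
        "answers": [["Paris"], ["a village in France"]],
        "all_classes": [None, None],
    }
)
print(calculate_metrics(df))  # mean qa_f1_score over the rows, scaled to 0-100
```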
