Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions benchmarks/evals/e2e/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ class EvalConfigs:
"quest-256",
"quest-512",
"quest-1024",
"quest_optimized-64",
"quest_optimized-128",
"quest_optimized-256",
"quest_optimized-512",
"quest_optimized-1024",
"raas-64",
"raas-128",
"raas-256",
Expand Down Expand Up @@ -182,15 +187,18 @@ def load_model_for_approach(self, model_name: str, approach_name: str) -> AutoMo

model_config = self.configs.model_config
if model_config.model_type == "llama":
from transformers import LlamaForCausalLM

optimized = ("optimized" in approach_name)

if approach_name == "full" or "sink" in approach_name: # They differ only in cache type
from transformers import LlamaForCausalLM
model = LlamaForCausalLM.from_pretrained(
model_name,
device_map="cuda:0",
trust_remote_code=True,
)
elif "h2o" in approach_name:
from transformers import LlamaForCausalLM
from quest.models.h2o_llama import enable_h2o_attention_eval

model = LlamaForCausalLM.from_pretrained(
Expand All @@ -202,9 +210,25 @@ def load_model_for_approach(self, model_name: str, approach_name: str) -> AutoMo
model,
{"cache_budget": int(approach_name.split("-")[-1])},
)
elif "quest" in approach_name:
elif "quest" in approach_name and optimized:
from quest.models.quest_llama_optimized import LlamaForCausalLM
from quest.models.quest_llama_optimized import enable_quest_attention_eval
model = LlamaForCausalLM.from_pretrained(
model_name,
device_map="cuda:0",
trust_remote_code=True,
torch_dtype=torch.float16, # Use float16 for optimized version
)
enable_quest_attention_eval(
model,
{
"cache_budget": int(approach_name.split("-")[-1]),
"page_size": 16, # Fixed as stated in the paper
},
)
elif "quest" in approach_name and not optimized:
from transformers import LlamaForCausalLM
from quest.models.quest_llama import enable_quest_attention_eval

model = LlamaForCausalLM.from_pretrained(
model_name,
device_map="cuda:0",
Expand All @@ -218,6 +242,7 @@ def load_model_for_approach(self, model_name: str, approach_name: str) -> AutoMo
},
)
elif "raas" in approach_name:
from transformers import LlamaForCausalLM
from quest.models.raas_llama import enable_raas_attention_eval

model = LlamaForCausalLM.from_pretrained(
Expand Down Expand Up @@ -353,7 +378,6 @@ def test_model(

cache_budget = int(self.configs.approach.split("-")[-1])
past_key_values = RaaSCache(page_size=16, cache_budget=cache_budget)

with torch.no_grad():

# Prefill
Expand Down Expand Up @@ -403,6 +427,8 @@ def test_model(
JCT = prefill_time + np.sum(decode_time)
TPOT = np.sum(decode_time) / num_decode

if "optimized" in self.configs.approach:
pipe.model.reset_model()
model_output = pipe.tokenizer.decode(generated_content, skip_special_tokens=True)
return model_output, TTFT, JCT, TPOT, num_decode

Expand Down
Loading