diff --git a/real-multi-round-qa/README.md b/real-multi-round-qa/README.md index 26ea55d..1b09ec0 100644 --- a/real-multi-round-qa/README.md +++ b/real-multi-round-qa/README.md @@ -2,70 +2,74 @@ ## Overview -This benchmark is designed to identify **the maximum harmonic mean of user sessions $(C,S)$ that can be kept active while maintaining a steady-state TTFT ≤ 2 s (95-th percentile)**. By sweeping the concurrency ($C$) and sequential ($S$) independently, it isolates whether compute capacity or KV-cache pressure is the first limiting factor. +This benchmark is designed to explore how TTFT changes across different $(C, S)$ combinations by sweeping concurrency ($C$) and session depth ($S$) independently. This helps isolate whether compute capacity or KV-cache pressure is the primary limiting factor. - -We highly recommend monitoring vLLM/LMCache/GPU/storage metrics at the same time. +We highly recommend monitoring vLLM/LMCache/GPU/storage metrics at the same time. The JSON output from the benchmark includes metrics from vLLM/LMCache. This benchmark feeds full‑length novels to your LLM server and asks many follow‑up questions, just like a book critic. It is handy for testing long‑context handling and KV‑cache tools such as LMCache. -The benchmark is called CxS (pronounced six for simplicity), referring to the product of Concurrent $\times$ Sequential users. - -### Definition - -Let us define the set of candidate pairs: - -$$ -\mathcal{D} = \{ (C_i, S_i) \mid \mathrm{TTFT}_{95}^{(i)} \leq 2 \} -$$ - -### Objective - -More precisely, we aim to find the pair that maximizes the harmonic mean among all candidates in $\mathcal{D}$: - - -$$ -\underset{(C_i, S_i) \in \mathcal{D}}{\arg\max} \left( \frac{2 C_i S_i}{C_i + S_i} \right) -$$ - -We use the harmonic mean to compare scores. -As a business metric, we report the product, CxS. -For example, we say "Our system can keep up to {C×S} user sessions active!" +The benchmark is called CxS (pronounced six for simplicity), referring to the product of Concurrent $\times$ Session Depth. ## Two simple knobs | Option | What it means | | ---- | ---- | -| `--num-users-concurrent` (C) | How many threads run in parallel. | -| `--num-users-sequential` (S) | How many users each thread serves in turn. | +| `--concurrent` (C) | How many threads run in parallel. | +| `--session-depth` (S) | How many sessions each thread serves in turn. | You can: -* raise concurrent to test compute-side capability (higher GPU utilization; total KV footprint also rises). -* raise sequential to test KV-cache pressure (larger resident KV per GPU, little change in instantaneous GPU utilization). +* raise $C$ to test compute-side capability (higher GPU utilization; total KV footprint also rises). +* raise $S$ to test KV-cache pressure (larger resident KV per GPU, little change in instantaneous GPU utilization). ## Execution model ``` -Concurrent USER: {A,B} -Sequential USER: {X,Y} -All USER: {AX,AY,BX,BY} +Concurrent: {A,B} +Session Depth: {X,Y} +All Session: {AX,AY,BX,BY} Timeline ------------------------------------------------- Thread A: - Turn 0 → UserAX: Q1 "Read and summarize this novel. {AX novel contents}" → Get Response - Turn 0 → UserAY: Q1 "Read and summarize this novel. {AY novel contents}" → Get Response - Turn 1 → UserAX: Q2 "Write down the author's feelings." → Get Response - Turn 1 → UserAY: Q2 "Write down the author's feelings." → Get Response + Turn 0 → SessionAX: Q1 "Read and summarize this novel. 
{AX novel contents}" → Get Response + Turn 0 → SessionAY: Q1 "Read and summarize this novel. {AY novel contents}" → Get Response + Turn 1 → SessionAX: Q2 "Write down the author's feelings." → Get Response + Turn 1 → SessionAY: Q2 "Write down the author's feelings." → Get Response ... Thread B: - Turn 0 → UserBX: Q1 "Read and summarize this novel. {BX novel contents}" → Get Response - Turn 0 → UserBY: Q1 "Read and summarize this novel. {BY novel contents}" → Get Response - Turn 1 → UserBX: Q2 "Write down the author's feelings." → Get Response - Turn 1 → UserBY: Q2 "Write down the author's feelings." → Get Response + Turn 0 → SessionBX: Q1 "Read and summarize this novel. {BX novel contents}" → Get Response + Turn 0 → SessionBY: Q1 "Read and summarize this novel. {BY novel contents}" → Get Response + Turn 1 → SessionBX: Q2 "Write down the author's feelings." → Get Response + Turn 1 → SessionBY: Q2 "Write down the author's feelings." → Get Response ... ``` +## For system competition + +The CxS benchmark provides a scalar score to encourage healthy competition, but its use is not mandatory. + +### Definition + +Let us define the set of candidate pairs: + +$$ +\mathcal{D} = {\{ (C_i, S_i) \mid \mathrm{TTFT}_{95}^{(i)} \leq 2 \}} +$$ + +### Objective + +More precisely, we aim to find the pair that maximizes the harmonic mean among all candidates in $\mathcal{D}$: + + +$$ +\underset{(C_i, S_i) \in \mathcal{D}}{\arg\max} \left( \frac{2 C_i S_i}{C_i + S_i} \right) +$$ + +## For business metric + +As a business metric, we report the product, CxS. +For example, we say "Our system can keep up to {C×S} user sessions active!" + ## Getting Started ```bash @@ -75,9 +79,9 @@ python prepare.py --output data --model Qwen/Qwen2.5-7B-Instruct-1M # Models use ```bash # Run the benchmark many times -BASE_URL="http://localhost:8000/v1" +BASE_URL="http://localhost:8000" MODEL="Qwen/Qwen2.5-7B-Instruct-1M" -NUM_ROUNDS=3 +NUM_ROUNDS=12 OUTPUT_DIR="bench_dir" SRC_DIR="./data/128k" mkdir -p "$OUTPUT_DIR" @@ -86,67 +90,107 @@ for c in {1..4}; do # You can change c and s to any value you like. 
for s in {1..4}; do TIMESTAMP=$(date +%s) OUTPUT_FILE="${OUTPUT_DIR}/bench_c${c}_s${s}_${TIMESTAMP}.json" - echo "Running benchmark: concurrent=${c}, sequential=${s}" - python multi-round-qa.py --num-users-concurrent "$c" --num-users-sequential "$s" --num-rounds "$NUM_ROUNDS" --model "$MODEL" --base-url "$BASE_URL" --output "$OUTPUT_FILE" --src-dir "$SRC_DIR" + echo "Running benchmark: C=${c}, S=${s}" + python multi-round-qa.py -c "$c" -s "$s" --num-rounds "$NUM_ROUNDS" --model "$MODEL" --base-url "$BASE_URL" --output "$OUTPUT_FILE" --src-dir "$SRC_DIR" done done ``` +We compare two systems for demo: + +System A +* Model + * Qwen/Qwen2.5-7B-Instruct-1M +* Dataset + * 32k +* CPU/GPU + * NVIDIA GH200 480GB +* vLLM + * v0.9.0.1 + * enable prefix-caching + * enable chunked prefill +* LMCache + * local_cpu: True + * max_local_cpu_size: 200 + * pipelined_backend: True + * save_decode_cache: True + +System B +* Model + * Qwen/Qwen2.5-7B-Instruct-1M +* Dataset + * 32k +* CPU/GPU + * NVIDIA GH200 480GB +* vLLM + * v0.9.0.1 + * enable prefix-caching + * enable chunked prefill +* LMCache + * local_cpu: True + * max_local_cpu_size: 200 + * pipelined_backend: True + * save_decode_cache: True + * local_disk: file:///data/tmp + * max_local_disk_size: 400 +* Storage + * DDN EXAScaler 2.14.0 + * stripe count is 8 + * stripe size is 1MiB + ```bash # Plot and Show Result -$ python plot.py ./bench_dir_vllm vllm.png - num_users_concurrent num_users_sequential ttft_95 -0 4 2 0.498404 -1 4 4 33.565437 -2 4 3 0.794144 -3 1 4 0.311046 -4 2 2 0.406148 -5 2 4 0.459704 -6 2 3 0.326396 -7 1 2 0.411317 -8 3 3 0.378674 -9 2 1 0.445499 -10 3 4 42.531053 -11 1 3 0.455651 -12 4 1 0.504505 -13 3 2 0.393902 -14 3 1 0.364927 -15 1 1 0.379049 -Max harmonic mean (C,S) where TTFT_95 <= 2s: 3.43 - => C=4.0, S=3.0, CxS=12.0 -$ python plot.py ./bench_dir_lmcache lmcache.png - num_users_concurrent num_users_sequential ttft_95 -0 1 1 0.524989 -1 3 2 0.592148 -2 4 4 1.202544 -3 3 4 1.286755 -4 2 1 0.477370 -5 3 3 0.586793 -6 2 3 0.627655 -7 4 1 0.575724 -8 4 3 1.251918 -9 2 4 0.446477 -10 1 4 0.460711 -11 3 1 0.495073 -12 1 3 0.329389 -13 4 2 0.586223 -14 1 2 0.477946 -15 2 2 0.457463 -Max harmonic mean (C,S) where TTFT_95 <= 2s: 4.00 - => C=4.0, S=4.0, CxS=16.0 +$ python plot.py lmcache_bench_dir-1749973344 lmcache_with_cpu_200g.png + c s ttft_95 +0 8 16 2.674693 +1 12 32 3.268448 +2 4 32 2.496206 +3 16 16 3.310291 +4 4 8 0.146159 +5 8 32 2.801732 +6 12 24 3.283783 +7 12 16 3.185047 +8 12 8 0.390896 +9 4 24 0.217809 +10 16 8 3.799740 +11 8 8 0.347083 +12 16 24 3.171192 +13 16 32 3.032414 +14 8 24 3.383691 +15 4 16 0.253737 +Best (C,S) with TTFT_95 ≤ 2 s → C=12.0, S=8.0, HarmonicMean=9.60, C×S=96.0 +Saved: lmcache_with_cpu_200g.png +$ python plot.py lmcache_bench_dir-1749897431 lmcache_with_cpu_200g_exa_400g.png + c s ttft_95 +0 4 16 0.255378 +1 8 24 3.213307 +2 16 24 4.067904 +3 4 24 0.612876 +4 8 32 4.389398 +5 4 8 0.158686 +6 12 24 3.939205 +7 12 8 0.634048 +8 4 32 1.191106 +9 12 32 3.475115 +10 16 16 3.156051 +11 8 8 0.264291 +12 12 16 2.739532 +13 16 32 3.853057 +14 8 16 1.424959 +15 16 8 3.470811 +Best (C,S) with TTFT_95 ≤ 2 s → C=8.0, S=16.0, HarmonicMean=10.67, C×S=128.0 +Saved: lmcache_with_cpu_200g_exa_400g.png ``` - -LMCache allows 1.17x increase in the number of user sessions kept active at least. - -Note: LMCache has not yet reached its limit in this case, -so we can aim to further improve the score by changing C and S. 
+This result shows that adding external storage (DDN EXAScaler) as a tier in the KV cache can increase the number of active sessions. ## Viz -vllm.png +The white dashed line indicates the TTFT = 2s boundary. + +System A result: -![vLLM Plot](vllm.png) +![LMCache+CPU Plot](lmcache_with_cpu_200g.png) -lmcache.png +System B result: -![LMCache Plot](lmcache.png) +![LMCache+CPU+Storage400g Plot](lmcache_with_cpu_200g_exa_400g.png) diff --git a/real-multi-round-qa/lmcache.png b/real-multi-round-qa/lmcache.png deleted file mode 100644 index a8cbd1b..0000000 Binary files a/real-multi-round-qa/lmcache.png and /dev/null differ diff --git a/real-multi-round-qa/lmcache_with_cpu_200g.png b/real-multi-round-qa/lmcache_with_cpu_200g.png new file mode 100644 index 0000000..59bba5c Binary files /dev/null and b/real-multi-round-qa/lmcache_with_cpu_200g.png differ diff --git a/real-multi-round-qa/lmcache_with_cpu_200g_exa_400g.png b/real-multi-round-qa/lmcache_with_cpu_200g_exa_400g.png new file mode 100644 index 0000000..fc2b344 Binary files /dev/null and b/real-multi-round-qa/lmcache_with_cpu_200g_exa_400g.png differ diff --git a/real-multi-round-qa/multi-round-qa.py b/real-multi-round-qa/multi-round-qa.py index 1e384a5..b8478ef 100644 --- a/real-multi-round-qa/multi-round-qa.py +++ b/real-multi-round-qa/multi-round-qa.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, asdict from typing import List import json +import requests FIRST_PROMPT = "Read and summarize this novel.\n\n{}" FOLLOWUP_PROMPTS = [ @@ -57,11 +58,13 @@ class Result: session_id: str turn: int + start_time: float latency: float ttft: float generation_time: float prompt_tokens: int completion_tokens: int + metrics: str status: str class ChatSession: @@ -98,7 +101,7 @@ def append_assistant_message(self, content): self.messages.append({"role": "assistant", "content": content}) self.turns += 1 -async def run_turn(session: ChatSession, client: openai.AsyncOpenAI) -> Result: +async def run_turn(session: ChatSession, client: openai.AsyncOpenAI, base_url: str) -> Result: prompt = session.get_next_prompt() session.append_user_message(prompt) @@ -109,6 +112,10 @@ async def run_turn(session: ChatSession, client: openai.AsyncOpenAI) -> Result: prompt_tokens = 0 print(f"Session {session.session_id}, Turn {session.turns}: {prompt[:50]}...") + + resp = requests.get(f"{base_url}/metrics") + resp.raise_for_status() + response = await client.chat.completions.create( model=session.model, messages=session.messages, @@ -137,41 +144,43 @@ async def run_turn(session: ChatSession, client: openai.AsyncOpenAI) -> Result: result = Result( session_id=session.session_id, turn=session.turns, + start_time=start_time, latency=latency, ttft=ttft, generation_time=generation_time, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, + metrics=resp.text, status="success", ) - + session.append_assistant_message(content) return result async def run_group(args) -> List[Result]: - client = openai.AsyncOpenAI(base_url=args.base_url, api_key="EMPTY") - sessions = [ChatSession(args) for _ in range(args.num_users_sequential)] + client = openai.AsyncOpenAI(base_url=f"{args.base_url}/v1", api_key="EMPTY") + sessions = [ChatSession(args) for _ in range(args.session_depth)] results = [] while any(not s.is_finished() for s in sessions): for session in sessions: if session.is_finished(): continue - result = await run_turn(session, client) + result = await run_turn(session, client, args.base_url) results.append(result) return results async def run_all_concurrent(args): - 
tasks = [run_group(args) for _ in range(args.num_users_concurrent)] + tasks = [run_group(args) for _ in range(args.concurrent)] all_results = await asyncio.gather(*tasks) return [asdict(r) for group in all_results for r in group] def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--num-users-concurrent", type=int, required=True) - parser.add_argument("--num-users-sequential", type=int, required=True) + parser.add_argument("-c", "--concurrent", type=int, required=True) + parser.add_argument("-s", "--session-depth", type=int, required=True) parser.add_argument("--model", type=str, required=True) parser.add_argument("--base-url", type=str, required=True) parser.add_argument("--num-rounds", type=int, default=10) diff --git a/real-multi-round-qa/plot.py b/real-multi-round-qa/plot.py index f3a5c70..e083c53 100644 --- a/real-multi-round-qa/plot.py +++ b/real-multi-round-qa/plot.py @@ -1,89 +1,108 @@ +#!/usr/bin/env python3 import os -import json import glob +import json import argparse -import pandas as pd import numpy as np +import pandas as pd import matplotlib.pyplot as plt +from matplotlib.colors import LogNorm +from matplotlib import ticker def main(): - parser = argparse.ArgumentParser(description="Analyze number of user sessions from benchmark results.") + parser = argparse.ArgumentParser( + description="Visualize TTFT_95 across (C,S) without interpolation." + ) parser.add_argument("input_dir", help="Directory containing JSON files") - parser.add_argument("output", help="Output path for the 3D bar plot image") + parser.add_argument("output", help="Output image path") args = parser.parse_args() - json_files = glob.glob(f"{args.input_dir}/*.json") - all_params = [] summary_records = [] + fixed_params_seen = [] - for file in json_files: - with open(file, 'r') as f: + for path in glob.glob(os.path.join(args.input_dir, "*.json")): + with open(path) as f: data = json.load(f) - if "params" not in data or "results" not in data: - print(f"Skipping {file}: missing 'params' or 'results'") - continue - params = data["params"] - results = data["results"] - - params_fixed = {k: v for k, v in params.items() - if k not in ["num_users_concurrent", "num_users_sequential", "output"]} - all_params.append(params_fixed) - - df = pd.DataFrame(results) - df = df[df["turn"] != 0] - if df.empty: - continue - - ttft_95 = df["ttft"].quantile(0.95) - summary_records.append({ - "num_users_concurrent": params["num_users_concurrent"], - "num_users_sequential": params["num_users_sequential"], - "ttft_95": ttft_95 - }) - - if all_params: - first_params = all_params[0] - assert all(p == first_params for p in all_params), "Inconsistent fixed parameters" - - summary_df = pd.DataFrame(summary_records) - print(summary_df) - - if summary_df.empty: - print("No valid TTFT data to visualize.") + if "params" not in data or "results" not in data: + print(f"Skip {path}: missing section") + continue + + p = data["params"] + res = pd.DataFrame(data["results"]) + res = res[res["turn"] != 0] # pre-fill を除外 + if res.empty: + continue + + summary_records.append({ + "c" : p["concurrent"], + "s" : p["session_depth"], + "ttft_95" : res["ttft"].quantile(0.95), + }) + fixed_params_seen.append({k:v for k,v in p.items() + if k not in ("concurrent","session_depth","output")}) + + if not summary_records: + print("No valid data") return - - summary_df_sorted = summary_df.sort_values(by=['num_users_concurrent', 'num_users_sequential']) - - fig = plt.figure() - ax = fig.add_subplot(111, projection='3d') - - x = 
summary_df_sorted['num_users_concurrent'] - y = summary_df_sorted['num_users_sequential'] - z = np.zeros_like(x) - dz = summary_df_sorted['ttft_95'] - - colors = ['blue' if v <= 2 else 'red' for v in dz] - - ax.bar3d(x, y, z, dx=0.5, dy=0.5, dz=dz, color=colors, shade=True) - ax.set_xlabel('Concurrent Users (C)') - ax.set_ylabel('Sequential Users (S)') - ax.set_zlabel('95% Tail TTFT (s)') - plt.title('TTFT 95% Tail vs Concurrent/Sequential Users') - ax.invert_xaxis() - plt.savefig(args.output) - - # Max harmonic mean under 2s TTFT - summary_under_2s = summary_df[summary_df["ttft_95"] <= 2].copy() - if not summary_under_2s.empty: - summary_under_2s["harmonic_mean"] = 2 * summary_under_2s["num_users_concurrent"] * summary_under_2s["num_users_sequential"] / ( - summary_under_2s["num_users_concurrent"] + summary_under_2s["num_users_sequential"] - ) - best_row = summary_under_2s.sort_values("harmonic_mean", ascending=False).iloc[0] - product = best_row["num_users_concurrent"] * best_row["num_users_sequential"] - print(f"Max harmonic mean (C,S) where TTFT_95 <= 2s: {best_row['harmonic_mean']:.2f}") - print(f" => C={best_row['num_users_concurrent']}, S={best_row['num_users_sequential']}, CxS={product}") - else: - print("No data points with TTFT_95 <= 2s.") + if any(fp != fixed_params_seen[0] for fp in fixed_params_seen): + raise ValueError("Inconsistent fixed parameters across files") + + df = pd.DataFrame(summary_records) + print(df) + + grid = df.pivot(index="s", columns="c", values="ttft_95").sort_index(ascending=True) + S_vals = grid.index.values + C_vals = grid.columns.values + Z = np.ma.masked_invalid(grid.values) + + fig, ax = plt.subplots(figsize=(9, 7)) + + pcm = ax.pcolormesh( + C_vals, + S_vals, + Z, + shading="nearest", + cmap="plasma", + norm=LogNorm(vmin=0.01, vmax=100), + ) + + cbar = fig.colorbar(pcm, ax=ax) + cbar.set_label("TTFT_95 (s)") + cbar.set_ticks([0.01, 0.1, 1, 2, 4, 8, 16, 32, 64, 100]) + cbar.ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.2g")) + + C_mesh, S_mesh = np.meshgrid(C_vals, S_vals) + ax.contour( + C_mesh, S_mesh, Z, + levels=[2.0], + colors="white", + linewidths=2, + linestyles="dashed", + ) + + ax.scatter(df["c"], df["s"], s=40, c="black", marker="o", label="measured") + + ax.set_xlabel("Concurrent (C)") + ax.set_ylabel("Session Depth (S)") + ax.set_title("TTFT_95 Heatmap across (C, S) — no interpolation") + ax.set_xticks(C_vals) + ax.set_yticks(S_vals) + ax.grid(True, which="both", linestyle="--", alpha=0.3) + ax.legend(loc="upper right") + + ok = df[df["ttft_95"] <= 2.0].copy() + if not ok.empty: + ok["hmean"] = 2 * ok["c"] * ok["s"] / (ok["c"] + ok["s"]) + best = ok.loc[ok["hmean"].idxmax()] + ax.scatter(best["c"], best["s"], + s=160, c="cyan", edgecolors="black", marker="*", label="Best (C,S)") + print(f"Best (C,S) with TTFT_95 ≤ 2 s → C={best.c}, S={best.s}, " + f"HarmonicMean={best.hmean:.2f}, C×S={best.c*best.s}") + ax.legend(loc="upper right") + + fig.tight_layout() + fig.savefig(args.output) + print(f"Saved: {args.output}") if __name__ == "__main__": main() diff --git a/real-multi-round-qa/vllm.png b/real-multi-round-qa/vllm.png deleted file mode 100644 index 067e442..0000000 Binary files a/real-multi-round-qa/vllm.png and /dev/null differ
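
As a worked example of the scoring formulas in the updated README (writing HM for the harmonic mean, and using only the best points printed by `plot.py` in the demo runs above): System A reports C=12, S=8 and System B reports C=8, S=16, so

$$
\mathrm{HM}_A = \frac{2 \cdot 12 \cdot 8}{12 + 8} = \frac{192}{20} = 9.6, \qquad C \times S = 96
$$

$$
\mathrm{HM}_B = \frac{2 \cdot 8 \cdot 16}{8 + 16} = \frac{256}{24} \approx 10.67, \qquad C \times S = 128
$$

which matches the `HarmonicMean` and `C×S` values emitted by `plot.py` for the two systems.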