From 2475c361e0f97a5a7316dc5e946a095b1bd54822 Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Fri, 10 Oct 2025 10:41:49 +0800 Subject: [PATCH 1/8] mha decode add performance test --- tests/test_mha_decode.py | 3 ++- top/kernel/mha.py | 23 +++++++++++++---------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/tests/test_mha_decode.py b/tests/test_mha_decode.py index ad6d09a2..c59ec142 100644 --- a/tests/test_mha_decode.py +++ b/tests/test_mha_decode.py @@ -11,11 +11,12 @@ def test_mha_decode_kernel(B, S, H, D, tune): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=1, help='batch size') + parser.add_argument('--seqlen_q', type=int, default=1, help='sequence length') parser.add_argument('--seqlen_kv', type=int, default=8192, help='sequence length') parser.add_argument('--heads', type=int, default=32, help='num heads') parser.add_argument('--dim', type=int, default=128, help='head dim') parser.add_argument('--tune', action='store_true', default=True, help='tune the kernel') args = parser.parse_args() - B, S, H, D, tune = args.batch, args.seqlen_kv, args.heads, args.dim, args.tune + B, S, S_q, H, D, tune = args.batch, args.seqlen_kv, args.seqlen_q, args.heads, args.dim, args.tune test_mha_decode_kernel(B, S, H, D, tune) diff --git a/top/kernel/mha.py b/top/kernel/mha.py index ffa10f60..8d8842e5 100644 --- a/top/kernel/mha.py +++ b/top/kernel/mha.py @@ -786,10 +786,11 @@ class _MHA_decode_attention(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, num_split, config): BATCH, KV_CTX, H, D_HEAD = k.shape + Q_CTX = q.shape[1] - mod = _mha_decode(BATCH, H, 1, KV_CTX, D_HEAD)(**config) - glse = torch.empty((BATCH, H, num_split, 1), dtype=q.dtype, device=q.device) - Output_partial = torch.empty((BATCH, 1, H, num_split, D_HEAD), + mod = _mha_decode(BATCH, H, Q_CTX, KV_CTX, D_HEAD)(**config) + glse = torch.empty((BATCH, H, num_split, Q_CTX), dtype=q.dtype, 
device=q.device) + Output_partial = torch.empty((BATCH, Q_CTX, H, num_split, D_HEAD), dtype=q.dtype, device=q.device) return mod(q, k, v, glse, Output_partial) @@ -809,6 +810,7 @@ def __init__(self, num_heads, seqlen_kv, head_dim, + seqlen_q=1, threads=None, block_M=None, block_N=None, @@ -821,6 +823,7 @@ def __init__(self, self.batch_size = batch_size self.num_heads = num_heads self.seqlen_kv = seqlen_kv + self.seqlen_q = seqlen_q self.head_dim = head_dim block_M_ = 64 block_N_ = 64 if head_dim <= 128 else 32 @@ -839,15 +842,15 @@ def __init__(self, } self.tune = tune self.tune_config = None - self.program = _mha_decode(self.batch_size, self.num_heads, 1, self.seqlen_kv, + self.program = _mha_decode(self.batch_size, self.num_heads, self.seqlen_q, self.seqlen_kv, self.head_dim)(**self.config) # self.kernel = tilelang.compile(self.program, out_idx=[5]) self.profiler = self.program.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) - flops_per_matmul = 2.0 * batch_size * num_heads * seqlen_kv * head_dim + flops_per_matmul = 2.0 * batch_size * num_heads * seqlen_kv * head_dim * self.seqlen_q self.total_flops = 2 * flops_per_matmul def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor: - assert Q.dim() == 4 and Q.size(1) == 1, "Q must have shape (bsz, 1, H, D)" + assert Q.dim() == 4 and Q.size(1) == self.seqlen_q, "Q must have shape (bsz, S_q, H, D)" if self.tune_config is None and self.tune: self.autotune() config = self.tune_config if self.tune_config else self.config @@ -856,7 +859,7 @@ def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Te def autotune(self): best_result = _mha_decode( - self.batch_size, self.num_heads, 1, self.seqlen_kv, self.head_dim, tune=True) + self.batch_size, self.num_heads, self.seqlen_q, self.seqlen_kv, self.head_dim, tune=True) best_latency = best_result.latency best_config = best_result.config ref_latency = best_result.ref_latency @@ -876,7 +879,7 @@ def ref_program(cls, V: 
torch.Tensor, glse: torch.Tensor = None, Output_partial: torch.Tensor = None) -> torch.Tensor: - assert Q.dim() == 4 and Q.size(1) == 1, "Q must have shape (bsz, 1, H, D)" + assert Q.dim() == 4 and Q.size(1) == self.seqlen_q, "Q must have shape (bsz, 1, H, D)" dim = Q.size(-1) scores = torch.einsum('bqhd,bkhd->bqhk', Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) @@ -885,7 +888,7 @@ def ref_program(cls, return output def gen_inputs(self): - shape_q = self.batch_size, 1, self.num_heads, self.head_dim + shape_q = self.batch_size, self.seqlen_q, self.num_heads, self.head_dim shape_kv = self.batch_size, self.seqlen_kv, self.num_heads, self.head_dim Q = torch.randn(shape_q, dtype=self.dtype, device=self.device) K = torch.randn(shape_kv, dtype=self.dtype, device=self.device) @@ -905,7 +908,7 @@ def profile(self, warmup=500): if self.tune_config is None and self.tune: self.autotune() if self.tune_config: - self.program = _mha_decode(self.batch_size, self.num_heads, 1, self.seqlen_kv, + self.program = _mha_decode(self.batch_size, self.num_heads, self.seqlen_q, self.seqlen_kv, self.head_dim)(**self.tune_config) # self.kernel = tilelang.compile(self.program, out_idx=[5]) self.profiler = self.program.get_profiler( From 3c3357cc6b8d33876a7a8a8ca1a7a453719a93cb Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Fri, 10 Oct 2025 14:04:35 +0800 Subject: [PATCH 2/8] support performance test --- tests/test_mha_decode.py | 6 +++--- top/kernel/mha.py | 2 -- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/test_mha_decode.py b/tests/test_mha_decode.py index c59ec142..272851d0 100644 --- a/tests/test_mha_decode.py +++ b/tests/test_mha_decode.py @@ -2,8 +2,8 @@ from top import MHADecodeKernel -def test_mha_decode_kernel(B, S, H, D, tune): - kernel = MHADecodeKernel(B, H, S, D, tune=tune) +def test_mha_decode_kernel(B, S, H, D, S_q, tune): + kernel = MHADecodeKernel(B, H, S, D, seqlen_q=S_q, tune=tune) kernel.check() 
kernel.profile() @@ -19,4 +19,4 @@ def test_mha_decode_kernel(B, S, H, D, tune): args = parser.parse_args() B, S, S_q, H, D, tune = args.batch, args.seqlen_kv, args.seqlen_q, args.heads, args.dim, args.tune - test_mha_decode_kernel(B, S, H, D, tune) + test_mha_decode_kernel(B, S, H, D, S_q, tune) diff --git a/top/kernel/mha.py b/top/kernel/mha.py index 8d8842e5..98e31a17 100644 --- a/top/kernel/mha.py +++ b/top/kernel/mha.py @@ -850,7 +850,6 @@ def __init__(self, self.total_flops = 2 * flops_per_matmul def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor: - assert Q.dim() == 4 and Q.size(1) == self.seqlen_q, "Q must have shape (bsz, S_q, H, D)" if self.tune_config is None and self.tune: self.autotune() config = self.tune_config if self.tune_config else self.config @@ -879,7 +878,6 @@ def ref_program(cls, V: torch.Tensor, glse: torch.Tensor = None, Output_partial: torch.Tensor = None) -> torch.Tensor: - assert Q.dim() == 4 and Q.size(1) == self.seqlen_q, "Q must have shape (bsz, 1, H, D)" dim = Q.size(-1) scores = torch.einsum('bqhd,bkhd->bqhk', Q, K) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) From d6e8cd75a9cb815a561858c5f5a54c3e40403425 Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Fri, 10 Oct 2025 14:27:16 +0800 Subject: [PATCH 3/8] add fa op profile --- profile/common_tools.py | 432 +++++++++++++++++++++ profile/input_params/mha_decode_params.csv | 3 + profile/input_params/mha_params.csv | 3 + profile/mha_decode_profile_test.py | 90 +++++ profile/mha_profile_test.py | 90 +++++ pyproject.toml | 2 +- 6 files changed, 619 insertions(+), 1 deletion(-) create mode 100644 profile/common_tools.py create mode 100644 profile/input_params/mha_decode_params.csv create mode 100644 profile/input_params/mha_params.csv create mode 100644 profile/mha_decode_profile_test.py create mode 100644 profile/mha_profile_test.py diff --git a/profile/common_tools.py b/profile/common_tools.py new file mode 100644 
index 00000000..367b9277 --- /dev/null +++ b/profile/common_tools.py @@ -0,0 +1,432 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +"""Common performance tools for FA operator sweep runners. + +This module centralizes: +- CSV/TSV autodetection and robust reading/writing +- Header normalization + alias mapping +- Typed cell parsing (int/bool) +- Logging setup +- Subprocess execution with full stdout/stderr logging +- Metrics parsing from test scripts stdout +- A generic `run_sweep(...)` to remove duplication in per-op scripts +- Pretty table printing to stdout with a switch +""" + +from __future__ import annotations + +import csv +import io +import logging +import pathlib +import re +import subprocess +import sys +from typing import Callable, Dict, List, Optional, Tuple + + +# -------------------------- +# Regex for metrics +# -------------------------- +_FLOAT = r"([0-9]+(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)?)" + +# --- Strict, line-anchored patterns to avoid matching lines like "Best fwd latency" / "Best TFlops" --- +REF_LAT_RE = re.compile(rf"(?m)^\s*Reference\s+Latency:\s*{_FLOAT}\s*ms\b") +REF_TFLOPS_RE = re.compile(rf"(?m)^\s*Reference\s+FLOPs?:\s*{_FLOAT}\s*T[Ff][Ll][Oo][Pp][s]?\b") +LAT_RE = re.compile(rf"(?m)^\s*Latency:\s*{_FLOAT}\s*ms\b") +TFLOPS_RE = re.compile(rf"(?m)^\s*FLOPs?:\s*{_FLOAT}\s*T[Ff][Ll][Oo][Pp][s]?\b") + +def parse_stdout_metrics(stdout: str) -> Dict[str, str]: + """Parse final Reference Latency/FLOPs and Latency/FLOPs from stdout. + + Notes: + * Patterns are anchored at the start of the line to avoid accidental + matches such as "Best fwd latency" or "Best TFlops". + * If multiple matching lines appear, the last occurrence is taken as + the final result. 
+ """ + def _last_float(pattern: re.Pattern[str]) -> Optional[float]: + matches = list(pattern.finditer(stdout)) + if not matches: + return None + return float(matches[-1].group(1)) + + ref_lat = _last_float(REF_LAT_RE) + ref_tf = _last_float(REF_TFLOPS_RE) + lat = _last_float(LAT_RE) + tf = _last_float(TFLOPS_RE) + + return { + "ref_latency_ms": f"{ref_lat:.2f}" if ref_lat is not None else "", + "ref_tflops": f"{ref_tf:.2f}" if ref_tf is not None else "", + "latency_ms": f"{lat:.2f}" if lat is not None else "", + "tflops": f"{tf:.2f}" if tf is not None else "", + } + + +# -------------------------- +# Logging +# -------------------------- +def setup_logger(log_path: pathlib.Path) -> logging.Logger: + """Create a fresh logger; overwrites the log file.""" + logger = logging.getLogger("fa_perf_sweep") + logger.setLevel(logging.INFO) + logger.handlers.clear() + + fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s") + + # Open in write mode to avoid appending to old logs. + fh = logging.FileHandler(log_path, mode="w", encoding="utf-8") + fh.setLevel(logging.INFO) + fh.setFormatter(fmt) + logger.addHandler(fh) + + sh = logging.StreamHandler(sys.stdout) + sh.setLevel(logging.INFO) + sh.setFormatter(fmt) + logger.addHandler(sh) + + return logger + + +# -------------------------- +# CSV/TSV utilities +# -------------------------- +def _detect_dialect(sample: str) -> csv.Dialect: + """Detect CSV dialect (comma/tab/semicolon/pipe).""" + try: + return csv.Sniffer().sniff(sample, delimiters=[",", "\t", ";", "|"]) + except Exception: + if sample.count("\t") >= max(sample.count(","), sample.count(";"), sample.count("|")): + class _TSV(csv.Dialect): + delimiter = "\t" + quotechar = '"' + doublequote = True + escapechar = None + lineterminator = "\n" + quoting = csv.QUOTE_MINIMAL + skipinitialspace = False + return _TSV() + return csv.get_dialect("excel") + + +def read_rows_csv_any(path: pathlib.Path, logger: Optional[logging.Logger] = None) -> Tuple[List[Dict[str, 
str]], List[str]]: + """Read CSV/TSV with autodetected delimiter. Returns (rows, fieldnames).""" + raw_text = path.read_text(encoding="utf-8-sig") + sample = raw_text[:4096] + dialect = _detect_dialect(sample) + if logger: + logger.info("Detected delimiter: %r", getattr(dialect, "delimiter", ",")) + + reader = csv.DictReader(io.StringIO(raw_text), dialect=dialect) + fieldnames = reader.fieldnames or [] + if logger: + logger.info("Raw headers: %s", fieldnames) + rows = list(reader) + return rows, fieldnames + + +def _build_out_fieldnames(input_fieldnames: List[str], extra_cols: List[str]) -> List[str]: + """Compute final CSV header = input headers + extra result columns (dedup).""" + return input_fieldnames + [c for c in extra_cols if c not in input_fieldnames] + + +def write_results_csv( + out_path: pathlib.Path, + input_fieldnames: List[str], + results_rows: List[Dict[str, str]], + extra_cols: List[str], +) -> List[str]: + """Write output CSV preserving original columns plus extra result columns. + + Returns: + The final header used (out_fieldnames). 
+ """ + out_fieldnames = _build_out_fieldnames(input_fieldnames, extra_cols) + out_path.parent.mkdir(parents=True, exist_ok=True) + with out_path.open("w", newline="", encoding="utf-8") as f_out: + writer = csv.DictWriter(f_out, fieldnames=out_fieldnames) + writer.writeheader() + for row in results_rows: + writer.writerow(row) + return out_fieldnames + + +# -------------------------- +# Header + value helpers +# -------------------------- +def norm_header(s: str) -> str: + """Normalize header names: strip, lower, unify separators.""" + s = s.strip().lower().replace("\ufeff", "") + s = s.replace(" ", "_").replace("-", "_") + return s + + +def build_header_map(fieldnames: List[str], alias_map: Dict[str, List[str]]) -> Dict[str, str]: + """Build mapping canonical -> actual column given aliases.""" + norm_to_actual = {norm_header(f): f for f in fieldnames if f is not None} + mapping: Dict[str, str] = {} + for canonical, cands in alias_map.items(): + for cand in cands: + key = norm_header(cand) + if key in norm_to_actual: + mapping[canonical] = norm_to_actual[key] + break + return mapping + + +def get_cell(row: Dict[str, str], header_map: Dict[str, str], canonical: str, default: Optional[str]) -> Optional[str]: + """Safe getter using canonical name with fallback default.""" + actual = header_map.get(canonical) + if actual is None: + return default + return row.get(actual, default) + + +def to_int(s: Optional[str], default: int) -> int: + """Robust int parsing with defaults.""" + if s is None: + return default + s = s.strip() + if s == "": + return default + try: + return int(float(s)) + except Exception: + return default + + +def to_bool(s: Optional[str], default: bool) -> bool: + """Truthiness parse for boolean flags.""" + if s is None: + return default + v = s.strip().lower() + if v in {"true", "1", "yes", "y"}: + return True + if v in {"false", "0", "no", "n"}: + return False + return default + + +# -------------------------- +# Subprocess + metrics +# 
--------------------------
+def run_and_log(cmd: List[str], logger: logging.Logger) -> Tuple[int, str, str]:
+    """Run subprocess, return (returncode, stdout, stderr), log full outputs."""
+    logger.info("Command: %s", " ".join(cmd))
+    try:
+        proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False)
+    except FileNotFoundError as e:
+        logger.exception("Python or script not found: %s", e)
+        return 127, "", str(e)
+
+    stdout = proc.stdout or ""
+    stderr = proc.stderr or ""
+
+    if stdout.strip():
+        logger.info("===== STDOUT BEGIN =====\n%s\n===== STDOUT END =====", stdout.rstrip("\n"))
+    if stderr.strip():
+        logger.warning("===== STDERR BEGIN =====\n%s\n===== STDERR END =====", stderr.rstrip("\n"))
+
+    return proc.returncode, stdout, stderr
+
+
+def parse_stdout_metrics(stdout: str) -> Dict[str, str]:
+    """Parse the FINAL Reference Latency/FLOPs and Latency/FLOPs (last match wins)."""
+    ref_lat = (REF_LAT_RE.findall(stdout) or [None])[-1]
+    ref_tf = (REF_TFLOPS_RE.findall(stdout) or [None])[-1]
+    lat = (LAT_RE.findall(stdout) or [None])[-1]
+    tf = (TFLOPS_RE.findall(stdout) or [None])[-1]
+
+    return {
+        "ref_latency_ms": f"{float(ref_lat):.2f}" if ref_lat is not None else "",
+        "ref_tflops": f"{float(ref_tf):.2f}" if ref_tf is not None else "",
+        "latency_ms": f"{float(lat):.2f}" if lat is not None else "",
+        "tflops": f"{float(tf):.2f}" if tf is not None else "",
+    }
+
+
+def tail_line(text: str) -> str:
+    """Return the last non-empty line of a multi-line string for concise notes."""
+    lines = [ln for ln in text.strip().splitlines() if ln.strip()]
+    return lines[-1] if lines else ""
+
+
+# --------------------------
+# Pretty table printing
+# --------------------------
+def _truncate(s: str, limit: int) -> str:
+    if limit <= 3 or len(s) <= limit:
+        return s[:limit]
+    return s[: limit - 3] + "..."
+ + +def _compute_col_widths(headers: List[str], rows: List[Dict[str, str]], max_col_width: int) -> List[int]: + widths = [len(h) for h in headers] + for row in rows: + for i, h in enumerate(headers): + val = str(row.get(h, "")) + widths[i] = min(max(widths[i], len(val)), max_col_width) + return widths + + +def _draw_row(cells: List[str], widths: List[int]) -> str: + parts = [] + for c, w in zip(cells, widths): + parts.append(" " + c.ljust(w) + " ") + return "|" + "|".join(parts) + "|" + + +def _draw_sep(widths: List[int]) -> str: + parts = [] + for w in widths: + parts.append("-" * (w + 2)) + return "+" + "+".join(parts) + "+" + + +def print_table(out_fieldnames: List[str], results_rows: List[Dict[str, str]], max_col_width: int = 32) -> None: + """Print a pretty ASCII table of the results to stdout.""" + headers = out_fieldnames + # Prepare truncated rows to avoid super-wide lines. + trunc_rows: List[Dict[str, str]] = [] + for row in results_rows: + new_row = {} + for h in headers: + new_row[h] = _truncate(str(row.get(h, "")), max_col_width) + trunc_rows.append(new_row) + + widths = _compute_col_widths(headers, trunc_rows, max_col_width) + sep = _draw_sep(widths) + + print(sep) + print(_draw_row(headers, widths)) + print(sep) + for row in trunc_rows: + print(_draw_row([str(row.get(h, "")) for h in headers], widths)) + print(sep) + + +# -------------------------- +# Generic sweep driver +# -------------------------- +RowCmdBuilder = Callable[[Dict[str, str], Dict[str, str], str, pathlib.Path], Tuple[List[str], str]] + +def run_sweep( + *, + operator_name: str, + in_path: pathlib.Path, + out_path: pathlib.Path, + script_path: pathlib.Path, + python_bin: str, + log_path: pathlib.Path, + alias_map: Dict[str, List[str]], + row_cmd_builder: RowCmdBuilder, + # NEW: + print_table_enable: bool = True, + table_max_col_width: int = 32, + table_columns: Optional[List[str]] = None, +) -> None: + """Generic sweep runner to minimize duplication. 
+
+    After writing results, optionally prints a pretty ASCII table to stdout.
+
+    Args:
+        print_table_enable: Whether to print the results table to stdout.
+        table_max_col_width: Max width of each column when printing.
+        table_columns: If provided, only print these columns (that exist).
+    """
+    # If the log file exists, announce and remove it before creating the logger.
+    if log_path.exists():
+        # stdout print so the user sees it even before logger is ready.
+        print(f"[INFO] Existing log found, removing: {log_path}")
+        try:
+            log_path.unlink()
+        except Exception as e:
+            # If deletion fails, we still proceed; setup_logger(...) opens with mode='w'.
+            print(f"[WARN] Failed to remove existing log: {e}")
+
+    logger = setup_logger(log_path)
+    logger.info("[%s] Input: %s", operator_name, in_path)
+    logger.info("[%s] Output: %s", operator_name, out_path)
+    logger.info("[%s] Script: %s", operator_name, script_path)
+    logger.info("[%s] Python: %s", operator_name, python_bin)
+    logger.info("[%s] Log: %s", operator_name, log_path)
+
+    if not in_path.exists():
+        logger.error("Input file not found: %s", in_path)
+        sys.exit(2)
+    if not script_path.exists():
+        logger.error("Test script not found: %s", script_path)
+        sys.exit(2)
+
+    rows, fieldnames = read_rows_csv_any(in_path, logger)
+    if not rows:
+        logger.error("No rows found in input.")
+        sys.exit(3)
+
+    header_map = build_header_map(fieldnames, alias_map)
+    logger.info("Header mapping (canonical -> actual): %s", header_map)
+
+    results_rows: List[Dict[str, str]] = []
+    for idx, row in enumerate(rows, start=1):
+        try:
+            cmd, screen_msg = row_cmd_builder(row, header_map, python_bin, script_path)
+        except Exception as e:
+            logger.exception("Row %d: failed to build command: %s", idx, e)
+            out_row = dict(row)
+            out_row.update({
+                "ref_latency_ms": "", "ref_tflops": "",
+                "latency_ms": "", "tflops": "",
+                "returncode": "-1",
+                "note": f"build_cmd_error={type(e).__name__}:{e}",
+            })
+            results_rows.append(out_row)
+            continue
+
+        
print(screen_msg) + logger.info(screen_msg) + + rc, stdout, stderr = run_and_log(cmd, logger) + metrics = parse_stdout_metrics(stdout) + note = "" + if rc != 0: + note_parts = [f"rc={rc}"] + last = tail_line(stderr) + if last: + note_parts.append(f"stderr_tail={last}") + note = ";".join(note_parts) + + out_row = dict(row) + out_row.update(metrics) + out_row["returncode"] = str(rc) + out_row["note"] = note + results_rows.append(out_row) + + print(f"[DONE {idx}/{len(rows)}] latency_ms={metrics.get('latency_ms','')} TFLOPs={metrics.get('tflops','')}") + logger.info("[DONE %d/%d] latency_ms=%s TFLOPs=%s", + idx, len(rows), metrics.get("latency_ms", ""), metrics.get("tflops", "")) + + # Write to CSV file + out_fieldnames = write_results_csv( + out_path, + fieldnames, + results_rows, + extra_cols=["ref_latency_ms", "ref_tflops", "latency_ms", "tflops", "returncode", "note"], + ) + logger.info("[%s] All done. Results saved to: %s", operator_name, out_path) + + # Pretty table (optional) + if print_table_enable: + if table_columns: + # Filter to a subset if requested and exists. + cols = [c for c in table_columns if c in out_fieldnames] + if cols: + filtered_rows = [] + for r in results_rows: + filtered_rows.append({c: r.get(c, "") for c in cols}) + print_table(cols, filtered_rows, max_col_width=table_max_col_width) + else: + # Fall back to all if provided columns are invalid. 
+ print_table(out_fieldnames, results_rows, max_col_width=table_max_col_width) + else: + print_table(out_fieldnames, results_rows, max_col_width=table_max_col_width) diff --git a/profile/input_params/mha_decode_params.csv b/profile/input_params/mha_decode_params.csv new file mode 100644 index 00000000..efce0d3b --- /dev/null +++ b/profile/input_params/mha_decode_params.csv @@ -0,0 +1,3 @@ +Provider OP bs head_num kv_heads seq_len kv_seq_len dim +LLAMA-70B FlashMHA 64 64 64 64 1024 128 +LLAMA-70B FlashMHA 64 64 64 64 2048 128 \ No newline at end of file diff --git a/profile/input_params/mha_params.csv b/profile/input_params/mha_params.csv new file mode 100644 index 00000000..39577e65 --- /dev/null +++ b/profile/input_params/mha_params.csv @@ -0,0 +1,3 @@ +Provider OP bs head_num kv_heads seq_len kv_seq_len dim causal +LLAMA-70B FlashMHA 64 64 64 1024 1024 128 TRUE +LLAMA-70B FlashMHA 64 64 64 2048 2048 128 TRUE \ No newline at end of file diff --git a/profile/mha_decode_profile_test.py b/profile/mha_decode_profile_test.py new file mode 100644 index 00000000..2104fc76 --- /dev/null +++ b/profile/mha_decode_profile_test.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +"""Sweep runner for MHA-Decode using common_tools.run_sweep.""" + +from __future__ import annotations + +import argparse +import pathlib +import sys +from typing import Dict, List, Tuple + +from common_tools import ( + RowCmdBuilder, + get_cell, + run_sweep, + to_bool, + to_int, +) + + +ALIASES_DECODE: Dict[str, List[str]] = { + "bs": ["bs", "batch", "b"], + "seqlen_q": ["seqlen_q", "q_seq_len", "q_len", "seq_len_q"], + "seqlen_kv": ["seqlen_kv", "kv_seq_len", "kvsl", "kv_len", "seq_len_kv"], + "head_num": ["head_num", "heads", "head", "nheads", "h"], + "dim": ["dim", "hd", "head_dim", "d", "embed_dim"], + "tune": ["tune", "autotune", "enable_tune"], +} + + +def build_cmd_decode(row: Dict[str, str], header_map: Dict[str, str], python_bin: str, script_path: pathlib.Path) -> 
Tuple[List[str], str]: + batch = to_int(get_cell(row, header_map, "bs", "1"), 1) + seqlen_q = to_int(get_cell(row, header_map, "seqlen_q", "64"), 64) + seqlen_kv = to_int(get_cell(row, header_map, "seqlen_kv", "8192"), 8192) + heads = to_int(get_cell(row, header_map, "head_num", "32"), 32) + dim = to_int(get_cell(row, header_map, "dim", "128"), 128) + tune = to_bool(get_cell(row, header_map, "tune", "True"), True) + + cmd = [ + python_bin, str(script_path), + "--batch", str(batch), + "--seqlen_q", str(seqlen_q), + "--seqlen_kv", str(seqlen_kv), + "--heads", str(heads), + "--dim", str(dim), + ] + if tune: + cmd.append("--tune") + + msg = (f"[RUN DECODE] batch={batch}, seqlen_q={seqlen_q}, " + f"seqlen_kv={seqlen_kv}, heads={heads}, dim={dim}, tune={tune}") + return cmd, msg + + +def main() -> None: + parser = argparse.ArgumentParser(description="Sweep MHA-Decode from CSV/TSV using common_tools.run_sweep.") + parser.add_argument("--input", required=True, help="Path to input CSV/TSV.") + parser.add_argument("--output", required=True, help="Path to output CSV.") + parser.add_argument("--script", default="test_mha_decode.py", help="Path to test_mha_decode.py.") + parser.add_argument("--python", default=sys.executable, help="Python interpreter.") + parser.add_argument("--log", default="mha_decode_sweep.log", help="Path to log file.") + + # 改为“默认打印表格”,提供关闭开关 + parser.add_argument("--no-print-table", action="store_true", + help="Disable pretty ASCII table printing (enabled by default).") + parser.add_argument("--table-max-col-width", type=int, default=32, + help="Max width for each printed column.") + parser.add_argument("--table-columns", type=str, default="", + help="Comma-separated subset of columns to print (optional).") + + args = parser.parse_args() + table_cols = [c.strip() for c in args.table_columns.split(",") if c.strip()] if args.table_columns else None + + run_sweep( + operator_name="DECODE", + in_path=pathlib.Path(args.input).expanduser().resolve(), + 
out_path=pathlib.Path(args.output).expanduser().resolve(), + script_path=pathlib.Path(args.script).expanduser().resolve(), + python_bin=args.python, + log_path=pathlib.Path(args.log).expanduser().resolve(), + alias_map=ALIASES_DECODE, + row_cmd_builder=build_cmd_decode, + print_table_enable=not args.no_print_table, + table_max_col_width=args.table_max_col_width, + table_columns=table_cols, + ) + + +if __name__ == "__main__": + main() diff --git a/profile/mha_profile_test.py b/profile/mha_profile_test.py new file mode 100644 index 00000000..c4e4c67f --- /dev/null +++ b/profile/mha_profile_test.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +"""Sweep runner for MHA using common_tools.run_sweep.""" + +from __future__ import annotations + +import argparse +import pathlib +import sys +from typing import Dict, List, Tuple + +from common_tools import ( + RowCmdBuilder, + get_cell, + run_sweep, + to_bool, + to_int, +) + + +ALIASES_MHA: Dict[str, List[str]] = { + "bs": ["bs", "batch", "b"], + "seq_len": ["seq_len", "sl", "seq", "sequence_len", "sequence_length"], + "head_num": ["head_num", "heads", "head", "nheads", "h"], + "dim": ["dim", "hd", "head_dim", "d", "embed_dim"], + "causal": ["causal", "is_causal", "mask_causal"], + "tune": ["tune", "autotune", "enable_tune"], +} + + +def build_cmd_mha(row: Dict[str, str], header_map: Dict[str, str], python_bin: str, script_path: pathlib.Path) -> Tuple[List[str], str]: + batch = to_int(get_cell(row, header_map, "bs", "8"), 8) + seq_len = to_int(get_cell(row, header_map, "seq_len", "1024"), 1024) + heads = to_int(get_cell(row, header_map, "head_num", "32"), 32) + dim = to_int(get_cell(row, header_map, "dim", "64"), 64) + causal = to_bool(get_cell(row, header_map, "causal", "False"), False) + tune = to_bool(get_cell(row, header_map, "tune", "True"), True) + + cmd = [ + python_bin, str(script_path), + "--batch", str(batch), + "--seq_len", str(seq_len), + "--heads", str(heads), + "--dim", str(dim), + ] + if 
tune: + cmd.append("--tune") + if causal: + cmd.append("--causal") + + msg = f"[RUN MHA] batch={batch}, seq_len={seq_len}, heads={heads}, dim={dim}, causal={causal}, tune={tune}" + return cmd, msg + + +def main() -> None: + parser = argparse.ArgumentParser(description="Sweep MHA from CSV/TSV using common_tools.run_sweep.") + parser.add_argument("--input", required=True, help="Path to input CSV/TSV.") + parser.add_argument("--output", required=True, help="Path to output CSV.") + parser.add_argument("--script", default="test_mha.py", help="Path to test_mha.py.") + parser.add_argument("--python", default=sys.executable, help="Python interpreter.") + parser.add_argument("--log", default="mha_sweep.log", help="Path to log file.") + + # 改为“默认打印表格”,提供关闭开关 + parser.add_argument("--no-print-table", action="store_true", + help="Disable pretty ASCII table printing (enabled by default).") + parser.add_argument("--table-max-col-width", type=int, default=32, + help="Max width for each printed column.") + parser.add_argument("--table-columns", type=str, default="", + help="Comma-separated subset of columns to print (optional).") + + args = parser.parse_args() + table_cols = [c.strip() for c in args.table_columns.split(",") if c.strip()] if args.table_columns else None + + run_sweep( + operator_name="MHA", + in_path=pathlib.Path(args.input).expanduser().resolve(), + out_path=pathlib.Path(args.output).expanduser().resolve(), + script_path=pathlib.Path(args.script).expanduser().resolve(), + python_bin=args.python, + log_path=pathlib.Path(args.log).expanduser().resolve(), + alias_map=ALIASES_MHA, + row_cmd_builder=build_cmd_mha, + print_table_enable=not args.no_print_table, + table_max_col_width=args.table_max_col_width, + table_columns=table_cols, + ) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index faa243b9..3028e91b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ build-backend = "setuptools.build_meta" 
[tool.setuptools.packages.find] where = ["."] -exclude = ["3rdparty", "3rdparty.*", "tests", "tests.*"] +exclude = ["3rdparty", "3rdparty.*", "tests", "tests.*", "profile", "profile.*"] [tool.yapf] based_on_style = "yapf" From 01f73542feffddf023a149257b9c718d37acc4ba Mon Sep 17 00:00:00 2001 From: RMLYC <472187190@qq.com> Date: Fri, 10 Oct 2025 16:49:54 +0800 Subject: [PATCH 4/8] add profile-test.sh and fix profile output log --- profile-test.sh | 214 ++++++++++++++++++++++++++++++++++++++++ profile/common_tools.py | 95 ++++++++++-------- top/kernel/mha.py | 8 +- 3 files changed, 274 insertions(+), 43 deletions(-) create mode 100755 profile-test.sh diff --git a/profile-test.sh b/profile-test.sh new file mode 100755 index 00000000..cebab83f --- /dev/null +++ b/profile-test.sh @@ -0,0 +1,214 @@ +#!/usr/bin/env bash +# profile-test.sh +# Orchestrates MHA and MHA-Decode sweeps and prints final CSVs as tables. + +set -euo pipefail + +############################################ +# User-editable settings # +############################################ + +# Python interpreter +PYTHON_BIN="python" + +# Sweep wrapper scripts +SWEEP_MHA="profile/mha_profile_test.py" +SWEEP_DECODE="profile/mha_decode_profile_test.py" + +# Underlying test scripts (passed into the sweep wrappers) +TEST_MHA="tests/test_mha.py" +TEST_DECODE="tests/test_mha_decode.py" + +# Input CSVs for each sweep +INPUT_MHA="profile/input_params/mha_params.csv" +INPUT_DECODE="profile/input_params/mha_decode_params.csv" + +# Terminal table rendering width +TABLE_MAX_COL_WIDTH=40 + +############################################ +# End of user-editable settings # +############################################ + +# ANSI colors +RED=$'\033[31m' +GREEN=$'\033[32m' +YELLOW=$'\033[33m' +CYAN=$'\033[36m' +BOLD=$'\033[1m' +RESET=$'\033[0m' + +# Fatal error helper +die() { echo "${RED}Error:${RESET} $*" >&2; exit 1; } + +# Section separator with optional title +sep() { + local title="${1:-}" + local 
line="==============================================================================" + if [[ -n "$title" ]]; then + echo "$line" + echo "${BOLD}${CYAN}$title${RESET}" + echo "$line" + else + echo "$line" + fi +} + +# Pretty-print a CSV file as a fixed-width ASCII table (embedded Python, no extra deps) +print_csv_as_table() { + local csv_path="$1" + local max_col_width="$2" + "$PYTHON_BIN" - "$csv_path" "$max_col_width" <<'PYCODE' +# -*- coding: utf-8 -*- +# Render a CSV as a fixed-width ASCII table (truncates long cells). + +import csv, sys, os +from typing import List, Dict + +def truncate(s: str, limit: int) -> str: + if limit <= 3 or len(s) <= limit: + return s[:limit] + return s[:limit-3] + "..." + +def col_widths(headers: List[str], rows: List[Dict[str,str]], maxw: int) -> List[int]: + w = [len(h) for h in headers] + for r in rows: + for i, h in enumerate(headers): + val = str(r.get(h, "")) + w[i] = min(max(w[i], len(val)), maxw) + return w + +def draw_row(cells: List[str], widths: List[int]) -> str: + parts = [] + for c, w in zip(cells, widths): + parts.append(" " + c.ljust(w) + " ") + return "|" + "|".join(parts) + "|" + +def draw_sep(widths: List[int]) -> str: + parts = [ "-" * (w + 2) for w in widths ] + return "+" + "+".join(parts) + "+" + +def main(): + if len(sys.argv) < 3: + print("Usage: