diff --git a/profile-test.sh b/profile-test.sh new file mode 100755 index 00000000..cebab83f --- /dev/null +++ b/profile-test.sh @@ -0,0 +1,214 @@ +#!/usr/bin/env bash +# profile-test.sh +# Orchestrates MHA and MHA-Decode sweeps and prints final CSVs as tables. + +set -euo pipefail + +############################################ +# User-editable settings # +############################################ + +# Python interpreter +PYTHON_BIN="python" + +# Sweep wrapper scripts +SWEEP_MHA="profile/mha_profile_test.py" +SWEEP_DECODE="profile/mha_decode_profile_test.py" + +# Underlying test scripts (passed into the sweep wrappers) +TEST_MHA="tests/test_mha.py" +TEST_DECODE="tests/test_mha_decode.py" + +# Input CSVs for each sweep +INPUT_MHA="profile/input_params/mha_params.csv" +INPUT_DECODE="profile/input_params/mha_decode_params.csv" + +# Terminal table rendering width +TABLE_MAX_COL_WIDTH=40 + +############################################ +# End of user-editable settings # +############################################ + +# ANSI colors +RED=$'\033[31m' +GREEN=$'\033[32m' +YELLOW=$'\033[33m' +CYAN=$'\033[36m' +BOLD=$'\033[1m' +RESET=$'\033[0m' + +# Fatal error helper +die() { echo "${RED}Error:${RESET} $*" >&2; exit 1; } + +# Section separator with optional title +sep() { + local title="${1:-}" + local line="==============================================================================" + if [[ -n "$title" ]]; then + echo "$line" + echo "${BOLD}${CYAN}$title${RESET}" + echo "$line" + else + echo "$line" + fi +} + +# Pretty-print a CSV file as a fixed-width ASCII table (embedded Python, no extra deps) +print_csv_as_table() { + local csv_path="$1" + local max_col_width="$2" + "$PYTHON_BIN" - "$csv_path" "$max_col_width" <<'PYCODE' +# -*- coding: utf-8 -*- +# Render a CSV as a fixed-width ASCII table (truncates long cells). + +import csv, sys, os +from typing import List, Dict + +def truncate(s: str, limit: int) -> str: + if limit <= 3 or len(s) <= limit: + return s[:limit] + return s[:limit-3] + "..." + +def col_widths(headers: List[str], rows: List[Dict[str,str]], maxw: int) -> List[int]: + w = [len(h) for h in headers] + for r in rows: + for i, h in enumerate(headers): + val = str(r.get(h, "")) + w[i] = min(max(w[i], len(val)), maxw) + return w + +def draw_row(cells: List[str], widths: List[int]) -> str: + parts = [] + for c, w in zip(cells, widths): + parts.append(" " + c.ljust(w) + " ") + return "|" + "|".join(parts) + "|" + +def draw_sep(widths: List[int]) -> str: + parts = [ "-" * (w + 2) for w in widths ] + return "+" + "+".join(parts) + "+" + +def main(): + if len(sys.argv) < 3: + print("Usage: