diff --git a/README.md b/README.md index 8847a29..bb4ab10 100644 --- a/README.md +++ b/README.md @@ -20,30 +20,39 @@ ______________________________________________________________________ ## 📰 News -- :fire: **2026-01-26 · [v0.1.2-alpha.1](https://github.com/tile-ai/TileRT/releases/tag/v0.1.2-alpha.1)**. **Multi-Token Prediction (MTP) lands in TileRT**. With mtp=3, we observe decoding rates up to **590 tokens/s** under synthetic workloads. +- :fire: **2026-02-14 · [Try the Online Demo](https://www.tilert.ai/)**. Our online demo is now live! Experience ultra-low-latency inference with **GLM-5** and **DeepSeek-V3.2**. [Try it now !](https://www.tilert.ai) + +- 🎉 **2026-02-14 · [v0.1.3](https://github.com/tile-ai/TileRT/releases/tag/v0.1.3) Released**. The v0.1.3 release introduces full support for the latest GLM-5 model, achieving up to 500 tokens/s on GLM-5-FP8 and up to 600 tokens/s on DeepSeek-V3.2. + +- 🚀 **2026-01-26 · [v0.1.2-alpha.1](https://github.com/tile-ai/TileRT/releases/tag/v0.1.2-alpha.1)**. **Multi-Token Prediction (MTP)** is now available in TileRT! With mtp=3, we achieve decoding rates of up to **590 tokens/s** under synthetic workloads. + +
+ Key Milestones - ⚡ **2025-12-23 · [v0.1.1](https://github.com/tile-ai/TileRT/releases/tag/v0.1.1)**. Achieved ~**35% further reduction** (3 ~ 4x speedup over baseline) in end-to-end token generation latency on a single node with **8× NVIDIA B200**. - 🚀 **2025-11-20 · [v0.1.0-alpha.1](https://github.com/tile-ai/TileRT/releases/tag/v0.1.0-alpha.1)**. Initial public release for **DeepSeek-V3.2-Exp**, targeting **ultra-low-latency** inference. Available on [PyPI](https://pypi.org/project/tilert) and [HuggingFace](https://huggingface.co/Tile-AI/DeepSeek-V3.2-Exp-TileRT). +
+ ______________________________________________________________________ -## TileRT: Pushing LLM Latency to the Limit +**TileRT** is a project designed to serve large language models (LLMs) in ultra-low-latency scenarios. Its goal is to push the latency limits of LLMs without compromising model size or quality—enabling models with hundreds of billions of parameters to achieve millisecond-level time per output token (TPOT). + +In our latest **v0.1.3** release, we tested **TileRT's** performance on the newest [**GLM-5**](https://huggingface.co/zai-org/GLM-5-FP8) model, demonstrating the effectiveness of our approach in real-world applications. We were among the first to support this latest model, validating the power of the technology we've developed. -TileRT is an experimental project exploring core compiler techniques for serving large language models (LLMs) in **ultra-low-latency** scenarios. Its goal is to push the latency limits of LLMs without compromising model size or quality—for example, enabling models with hundreds of billions of parameters to achieve millisecond-level **time per output token (TPOT)**. +Using the [**GLM-5**](https://huggingface.co/zai-org/GLM-5-FP8) model (without lossy optimizations such as quantization or distillation) with a batch size of 1 on 8× NVIDIA B200 GPUs, we evaluated TileRT’s preliminary performance. As shown in the benchmarks below, TileRT demonstrates substantial improvements over existing inference systems.

-TileRT Benchmark
-Figure 1. Sequence generation with TileRT, now enhanced with Multi-Token Prediction (MTP) to accelerate inference. +TileRT Benchmark
+Figure 1. Evaluation setup. Batch size: 1; Input sequence length: 1K, 16K, 32K, 64K, 128K, 150K, 192K; Output sequence length: 1K; Benchmark with synthetic data. SGLang v0.5.9.dev0 with MTP=3; vLLM v0.16.0rc2.dev173 with MTP=1 (vLLM failed when MTP=3, so we set MTP=1 as vLLM-GPT5-recipe); TileRT v0.1.3 with MTP=3.

-We evaluated TileRT’s preliminary performance using the [**DeepSeek-V3.2-Exp**](https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp) model (without lossy optimizations such as quantization or distillation) with a batch size of 1 on 8× NVIDIA B200 GPUs. As shown in the benchmark below, TileRT demonstrates substantial improvements over existing inference systems. -

-TileRT Benchmark
-Figure 2. Evaluation setup. Batch size: 1, Input sequence length/Output sequence length: 1K/1K; SGLang v0.5.6, TensorRT-LLM v1.2.0-rc5, vLLM v0.13.0, TileRT v0.1.1 with CUDA 12.9. +TileRT Benchmark
+Figure 2. Evaluation setup. Batch size: 1; Input sequence length: 1K, 16K, 32K, 64K, 128K, 150K, 192K; Output sequence length: 1K; Benchmark with synthetic data. SGLang v0.5.9.dev0; vLLM v0.16.0rc2.dev173; TileRT v0.1.3.

Unlike traditional inference systems optimized for high-throughput batch processing, TileRT prioritizes **responsiveness**, which is critical for applications such as high-frequency trading, interactive AI, real-time decision-making, long-running agents, and AI-assisted coding, where the latency of individual requests matters most. @@ -117,36 +126,46 @@ You're now ready to use TileRT! Proceed to the [Getting Started](#getting-starte ## Getting Started -### Download Pre-Converted Weights from HuggingFace +### Step 1: Download Official Model Weights + +Starting from release v0.1.3, TileRT no longer requires downloading pre-converted weights from Hugging Face. Instead, you can download the official model weights directly from the model's source (e.g., Hugging Face), and then convert them using the weight converter script included with the latest TileRT release. -TileRT requires preprocessing of the original DeepSeek-V3.2-Exp model weights before they can be used for ultra-low-latency inference. -To simplify this process, we provide **pre-converted weights** directly on HuggingFace so users do not need to run the preprocessing pipeline themselves. +### Step 2: Convert Weights Using `weight_converter.py` -You can download the weights using one of the recommended methods below: +After downloading the official model weights, you can use the following command to convert them into a format compatible with TileRT: -#### Option 1: Using `huggingface-cli` (recommended) +For **DeepSeek-V3.2**, run: ```bash -hf download Tile-AI/DeepSeek-V3.2-Exp-TileRT --local-dir ./tilert_weights +python -m tilert.models.preprocess.weight_converter \ + --model_type deepseek-v32 \ + --model_dir "/path/to/DeepSeek-V3.2" \ + --save_dir "/path/to/DeepSeek-V3.2-TileRT" ``` -This will download all files into the `./tilert_weights` directory. +Replace `/path/to/DeepSeek-V3.2` with the directory where you've downloaded the model weights, and `/path/to/DeepSeek-V3.2-TileRT` with the directory where you'd like the converted weights to be saved. -#### Option 2: Using Git + Git LFS +Similarly, for **GLM-5**, run: ```bash -git lfs install -git clone https://huggingface.co/Tile-AI/DeepSeek-V3.2-Exp-TileRT +python -m tilert.models.preprocess.weight_converter \ + --model_type glm-5 \ + --model_dir "/path/to/GLM-5-FP8" \ + --save_dir "/path/to/GLM-5-FP8-TileRT" ``` -For additional download methods or advanced usage, please refer to the official Hugging Face documentation. +Replace `/path/to/GLM-5-FP8` with the directory containing the downloaded GLM-5 model weights, and `/path/to/GLM-5-FP8-TileRT` with the desired location for saving the converted weights. + +### Step 3: Set the Converted Weights Directory -After downloading the weights, point TileRT to the directory using: +Once the weights are converted, set the environment variable to point TileRT to the directory containing the converted weights: ```bash -export MODEL_WEIGHTS_DIR=/path/to/tilert_weights +export MODEL_WEIGHTS_DIR= ... # converted weights ``` +Now you're ready to use TileRT with the converted weights! + ### Running the Generation Example After downloading the model weights, you can run the generation example within the Docker environment as follows: @@ -203,11 +222,6 @@ This example demonstrates basic single-step autoregressive generation using the ### Running the Generation Example with Multi-Token Prediction (MTP) -> \[!IMPORTANT\] -> **Weights update required for MTP.** Multi-Token Prediction (MTP) introduces additional **MTP heads** in the model weights. -> If you were using TileRT **before v0.1.1**, please make sure you download the **latest weights** from Hugging Face. -> Older weights do not include the required MTP heads and will fail to run when MTP is enabled. - TileRT also supports Multi-Token Prediction (MTP), which allows the model to generate multiple tokens per forward pass and reduces sequential decoding depth. To better illustrate MTP behavior, we use a longer prompt that encourages extended generation: diff --git a/assets/generate.gif b/assets/generate.gif deleted file mode 100644 index 3d73a90..0000000 Binary files a/assets/generate.gif and /dev/null differ diff --git a/assets/glm5-mtp.png b/assets/glm5-mtp.png new file mode 100644 index 0000000..d9ebb32 Binary files /dev/null and b/assets/glm5-mtp.png differ diff --git a/assets/glm5-without-mtp.png b/assets/glm5-without-mtp.png new file mode 100644 index 0000000..28ecf08 Binary files /dev/null and b/assets/glm5-without-mtp.png differ diff --git a/assets/logo.png b/assets/logo.png index b89b5f0..88af2ec 100644 Binary files a/assets/logo.png and b/assets/logo.png differ diff --git a/assets/perf.png b/assets/perf.png deleted file mode 100644 index c9804ad..0000000 Binary files a/assets/perf.png and /dev/null differ diff --git a/python/__init__.py b/python/__init__.py index 400d6a0..dbda493 100644 --- a/python/__init__.py +++ b/python/__init__.py @@ -50,7 +50,6 @@ def _load_library(filename: str) -> Any: from . import models # noqa: E402 -from .generate import ShowHandsGenerator # noqa: E402 from .models import deepseek_v3_2 # noqa: E402 from .tilert_init import tilert_init # noqa: E402 @@ -59,6 +58,5 @@ def _load_library(filename: str) -> Any: "tilert_init", "models", "deepseek_v3_2", - "ShowHandsGenerator", "__version__", ] diff --git a/python/benchmark/__init__.py b/python/benchmark/__init__.py new file mode 100644 index 0000000..49a349d --- /dev/null +++ b/python/benchmark/__init__.py @@ -0,0 +1,129 @@ +"""Benchmark suite for TileRT generation.""" + +from dataclasses import dataclass +from typing import TypeAlias + +from tilert.models.deepseek_v3_2.generator import DSAv32Generator +from tilert.models.glm_5.generator import GLM5Generator + +Generator: TypeAlias = DSAv32Generator | GLM5Generator + + +@dataclass +class BenchMode: + """Configuration for a single benchmark mode.""" + + with_mtp: bool + label: str + # Sampling parameters — None means keep current generator defaults (top-k1 argmax). + use_topp: bool = False + top_p: float = 1.0 + top_k: int = 256 + temperature: float = 1.0 + + +@dataclass +class CellStats: + """Stats for a single table cell (one mode x one benchmark column).""" + + tok_s: float = 0.0 + ms: float = 0.0 + acc_rate: str = "-" + + +BenchStats = dict[str, dict[str, CellStats]] + + +def apply_mode(generator: Generator, mode: BenchMode) -> None: + """Apply sampling parameters for a benchmark mode.""" + generator.update_sampling_params( + temperature=mode.temperature, + top_p=mode.top_p, + top_k=mode.top_k, + use_topp=mode.use_topp, + ) + + +def merge_stats(stats_list: list[BenchStats]) -> BenchStats: + """Merge multiple benchmark stats dicts by mode label.""" + merged: BenchStats = {} + for stats in stats_list: + for mode, cols in stats.items(): + merged.setdefault(mode, {}).update(cols) + return merged + + +def _fmt(number: float, suffix: str) -> str: + return f"{number:.3f} {suffix}" + + +def print_summary_table( + all_stats: BenchStats, + model_name: str, +) -> None: + """Print a markdown summary table from merged benchmark stats. + + Each mode occupies 3 rows: tok/s, ms, acc_rate. + """ + if not all_stats: + return + + # Collect column keys in insertion order (preserves benchmark ordering) + col_keys: list[str] = [] + for cols in all_stats.values(): + for k in cols: + if k not in col_keys: + col_keys.append(k) + + ROW_LABELS = ["tok/s", "ms", "acc"] + + # Build formatted cell strings: {mode: {col: [row0, row1, row2]}} + formatted: dict[str, dict[str, list[str]]] = {} + for mode, cols in all_stats.items(): + formatted[mode] = {} + for k in col_keys: + cell = cols.get(k) + if cell is None: + formatted[mode][k] = ["-", "-", "-"] + else: + formatted[mode][k] = [ + _fmt(cell.tok_s, "tok/s"), + _fmt(cell.ms, "ms"), + cell.acc_rate, + ] + + # Compute column widths + col_widths: dict[str, int] = {} + for k in col_keys: + w = len(k) + for mode_cells in formatted.values(): + for row_str in mode_cells.get(k, ["-"]): + w = max(w, len(row_str)) + col_widths[k] = w + + mode_width = max(len("Mode"), max(len(m) for m in all_stats)) + # Row label column shares the mode column; pick wider of mode names vs row labels + mode_width = max(mode_width, max(len(r) for r in ROW_LABELS)) + + print(f"\n## Benchmark Summary ({model_name})\n") + + # Header + hdr = [f" {'Mode':<{mode_width}} "] + hdr += [f" {k:<{col_widths[k]}} " for k in col_keys] + print("|" + "|".join(hdr) + "|") + + # Separator + sep = ["-" * (mode_width + 2)] + sep += ["-" * (col_widths[k] + 2) for k in col_keys] + print("|" + "|".join(sep) + "|") + + # Data rows: 3 rows per mode + mode_list = list(all_stats.keys()) + for _, mode in enumerate(mode_list): + for row_idx, _row_label in enumerate(ROW_LABELS): + label = mode if row_idx == 0 else "" + cells = [f" {label:<{mode_width}} "] + for k in col_keys: + cell_text = formatted[mode][k][row_idx] + cells.append(f" {cell_text:<{col_widths[k]}} ") + print("|" + "|".join(cells) + "|") diff --git a/python/benchmark/coding_prompt.py b/python/benchmark/coding_prompt.py new file mode 100644 index 0000000..e4ff6ed --- /dev/null +++ b/python/benchmark/coding_prompt.py @@ -0,0 +1,46 @@ +"""Coding-prompt benchmark: single generation, measures coding task throughput.""" + +from typing import cast + +import numpy as np +from benchmark import BenchMode, BenchStats, CellStats, Generator, apply_mode + +PROMPT = "Hi, can you write a sort program in C for me?" + + +def run(generator: Generator, modes: list[BenchMode]) -> BenchStats: + """Run the coding-prompt benchmark for each mode. + + Returns stats with column: Coding. + """ + stats: BenchStats = {} + + for mode in modes: + apply_mode(generator, mode) + print(f"\n--- Coding-prompt benchmark ({mode.label}) ---") + print(f"Prompt: {PROMPT}") + print("Completion:") + + _, time_list, accepted_counts = cast( + tuple[str, list[float], list[int]], + generator.generate(PROMPT, True, with_mtp=mode.with_mtp), + ) + + mode_stats: dict[str, CellStats] = {} + + if mode.with_mtp and accepted_counts: + total_tokens = sum(accepted_counts) + total_time = sum(time_list) + speed = total_tokens / total_time if total_time > 0 else 0 + avg_ms = total_time / len(time_list) * 1000 + avg_a = total_tokens / len(accepted_counts) + acc_rate = f"{avg_a:.2f}/{min(accepted_counts)}/{max(accepted_counts)}" + mode_stats["Coding"] = CellStats(tok_s=speed, ms=avg_ms, acc_rate=acc_rate) + elif time_list: + mean_time = float(np.mean(time_list)) + speed = 1 / mean_time + mode_stats["Coding"] = CellStats(tok_s=speed, ms=mean_time * 1000) + + stats[mode.label] = mode_stats + + return stats diff --git a/python/benchmark/long_prompt.py b/python/benchmark/long_prompt.py new file mode 100644 index 0000000..f6d4d0e --- /dev/null +++ b/python/benchmark/long_prompt.py @@ -0,0 +1,46 @@ +"""Long-prompt benchmark: single generation, measures long-form throughput.""" + +from typing import cast + +import numpy as np +from benchmark import BenchMode, BenchStats, CellStats, Generator, apply_mode + +PROMPT = "Hi, can you tell me a very long story, with roughly 3000 words?" + + +def run(generator: Generator, modes: list[BenchMode]) -> BenchStats: + """Run the long-prompt benchmark for each mode. + + Returns stats with column: Long. + """ + stats: BenchStats = {} + + for mode in modes: + apply_mode(generator, mode) + print(f"\n--- Long-prompt benchmark ({mode.label}) ---") + print(f"Prompt: {PROMPT}") + print("Completion:") + + _, time_list, accepted_counts = cast( + tuple[str, list[float], list[int]], + generator.generate(PROMPT, True, with_mtp=mode.with_mtp), + ) + + mode_stats: dict[str, CellStats] = {} + + if mode.with_mtp and accepted_counts: + total_tokens = sum(accepted_counts) + total_time = sum(time_list) + speed = total_tokens / total_time if total_time > 0 else 0 + avg_ms = total_time / len(time_list) * 1000 + avg_a = total_tokens / len(accepted_counts) + acc_rate = f"{avg_a:.2f}/{min(accepted_counts)}/{max(accepted_counts)}" + mode_stats["Long"] = CellStats(tok_s=speed, ms=avg_ms, acc_rate=acc_rate) + elif time_list: + mean_time = float(np.mean(time_list)) + speed = 1 / mean_time + mode_stats["Long"] = CellStats(tok_s=speed, ms=mean_time * 1000) + + stats[mode.label] = mode_stats + + return stats diff --git a/python/benchmark/short_prompt.py b/python/benchmark/short_prompt.py new file mode 100644 index 0000000..bebd2ce --- /dev/null +++ b/python/benchmark/short_prompt.py @@ -0,0 +1,89 @@ +"""Short-prompt benchmark: 20 iterations, measures steady-state decode throughput.""" + +from typing import cast + +import numpy as np +from benchmark import BenchMode, BenchStats, CellStats, Generator, apply_mode + +PROMPT = "Tell me 10 jokes, keep them all under 100 words." +NUM_ITERS = 20 +TOKEN_CHECKPOINTS = [200] + + +def run(generator: Generator, modes: list[BenchMode]) -> BenchStats: + """Run the short-prompt benchmark for each mode. + + Returns stats with columns: Short@ for each checkpoint. + """ + stats: BenchStats = {} + + for mode in modes: + apply_mode(generator, mode) + print(f"\n--- Short-prompt benchmark ({mode.label}) ---", flush=True) + + all_times: list[list[float]] = [] + all_accepted: list[list[int]] = [] + all_results: list[str] = [] + for _iter in range(NUM_ITERS): + if _iter % 5 == 0: + print(f" iter {_iter}/{NUM_ITERS}...", flush=True) + result, time_list, accepted_counts = cast( + tuple[str, list[float], list[int]], + generator.generate(PROMPT, False, with_mtp=mode.with_mtp), + ) + all_times.append(time_list) + all_accepted.append(accepted_counts) + all_results.append(result) + + # Verify determinism and print output once + mismatches = [i for i, r in enumerate(all_results) if r != all_results[0]] + if mismatches: + print(f" WARNING: non-deterministic output at iters {mismatches}") + print(f"Prompt: {PROMPT}") + print(f"Completion:\n{all_results[0]}") + + mode_stats: dict[str, CellStats] = {} + + if mode.with_mtp: + for token_num in TOKEN_CHECKPOINTS: + speeds: list[float] = [] + for time_list, accepted_list in zip(all_times, all_accepted): + if time_list and accepted_list: + cumsum_tokens = np.cumsum(accepted_list) + cumsum_times = np.cumsum(time_list) + idx = int(np.searchsorted(cumsum_tokens, token_num)) + # If total tokens < token_num, use all available data + if idx >= len(cumsum_times): + idx = len(cumsum_times) - 1 + tok_count = int(cumsum_tokens[idx]) + elapsed = float(cumsum_times[idx]) + if elapsed > 0: + speeds.append(tok_count / elapsed) + if speeds: + speed = float(np.mean(speeds)) + mean_time = 1 / speed + + flat_accepted = [a for al in all_accepted for a in al] + acc_rate = "-" + if flat_accepted: + avg_a = sum(flat_accepted) / len(flat_accepted) + acc_rate = f"{avg_a:.2f}/{min(flat_accepted)}/{max(flat_accepted)}" + + mode_stats[f"Short@{token_num}"] = CellStats( + tok_s=speed, ms=mean_time * 1000, acc_rate=acc_rate + ) + else: + for token_num in TOKEN_CHECKPOINTS: + per_token_times = [] + for time_list in all_times: + trimmed = time_list[:token_num] + if trimmed: + per_token_times.extend(trimmed) + if per_token_times: + mean_time = float(np.mean(per_token_times)) + speed = 1 / mean_time + mode_stats[f"Short@{token_num}"] = CellStats(tok_s=speed, ms=mean_time * 1000) + + stats[mode.label] = mode_stats + + return stats diff --git a/python/generate.py b/python/generate.py index 79f61b7..5724e8e 100644 --- a/python/generate.py +++ b/python/generate.py @@ -1,11 +1,58 @@ """Text generation script for TileRT.""" from argparse import ArgumentParser -from typing import cast -import numpy as np - -from tilert.models.deepseek_v3_2.dsa_show_hands import ShowHandsGenerator +from benchmark import BenchMode +from benchmark import coding_prompt as coding_bench +from benchmark import long_prompt as long_bench +from benchmark import merge_stats, print_summary_table +from benchmark import short_prompt as short_bench + +from tilert.models.deepseek_v3_2.generator import DSAv32Generator +from tilert.models.deepseek_v3_2.model_args import ModelArgs as DSAv32ModelArgs +from tilert.models.glm_5.generator import GLM5Generator +from tilert.models.glm_5.model_args import ModelArgsGLM5 + + +def get_generator( + model_type: str, + max_new_tokens: int, + temperature: float, + model_weights_dir: str, + with_mtp: bool, + top_p: float = 0.9, + top_k: int = 256, + enable_thinking: bool = False, + sampling_seed: int = 42, +) -> DSAv32Generator | GLM5Generator: + """Get the appropriate generator based on model type.""" + assert model_type in ["deepseek_v3_2", "glm5"] + if model_type == "deepseek_v3_2": + model_args = DSAv32ModelArgs() + return DSAv32Generator( + model_args=model_args, + max_new_tokens=max_new_tokens, + temperature=temperature, + model_weights_dir=model_weights_dir, + with_mtp=with_mtp, + top_p=top_p, + top_k=top_k, + use_topp=top_p < 1.0, + sampling_seed=sampling_seed, + ) + model_args = ModelArgsGLM5() + return GLM5Generator( + model_args=model_args, + max_new_tokens=max_new_tokens, + temperature=temperature, + model_weights_dir=model_weights_dir, + with_mtp=with_mtp, + top_p=top_p, + top_k=top_k, + use_topp=top_p < 1.0, + enable_thinking=enable_thinking, + sampling_seed=sampling_seed, + ) def parse_args(): # type: ignore @@ -16,8 +63,22 @@ def parse_args(): # type: ignore required=True, help="Path to model weights directory", ) + parser.add_argument( + "--model", + type=str, + default="deepseek_v3_2", + choices=["deepseek_v3_2", "glm5"], + help="Model type to use (default: deepseek_v3_2)", + ) parser.add_argument("--max-new-tokens", type=int, default=4000, help="Max tokens to generate") - parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature") + parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature") + parser.add_argument( + "--top-p", + type=float, + default=1.0, + help="Top-p (nucleus) sampling threshold. Use < 1.0 to enable top-p sampling (e.g. 0.9)", + ) + parser.add_argument("--top-k", type=int, default=256, help="Top-k sampling threshold") parser.add_argument("--interactive", action="store_true") parser.add_argument( "--with-mtp", @@ -29,6 +90,17 @@ def parse_args(): # type: ignore action="store_true", help="Use random weights instead of pretrained (for testing MTP without real weights)", ) + parser.add_argument( + "--enable-thinking", + action="store_true", + help="Enable thinking mode in chat template", + ) + parser.add_argument( + "--sampling-seed", + type=int, + default=42, + help="Sampling seed for top-p sampling (fixed per request, default: 42)", + ) return parser.parse_args() @@ -37,31 +109,47 @@ def parse_args(): # type: ignore usage: execute below command under tilert root directory: - # Standard generation with pretrained weights: + # DeepSeek V3.2 - Standard generation with pretrained weights: python python/generate.py --model-weights-dir "xxxx" 2>&1 | tee test.log - # MTP generation with random weights (for testing): + # DeepSeek V3.2 - MTP generation with random weights (for testing): python python/generate.py --model-weights-dir "xxxx" --with-mtp \ --use-random-weights 2>&1 | tee test.log - # MTP generation with pretrained weights (when available): + # DeepSeek V3.2 - MTP generation with pretrained weights (when available): python python/generate.py --model-weights-dir "xxxx" --with-mtp 2>&1 | tee test.log + + # GLM5 - Standard generation with random weights (for testing): + python python/generate.py --model glm5 --model-weights-dir "xxxx" \ + --use-random-weights 2>&1 | tee test.log + + # GLM5 - Standard generation with pretrained weights: + python python/generate.py --model glm5 --model-weights-dir "xxxx" 2>&1 | tee test.log + + # GLM5 - MTP generation with random weights (for testing): + python python/generate.py --model glm5 --model-weights-dir "xxxx" --with-mtp \ + --use-random-weights 2>&1 | tee test.log + + # GLM5 - MTP generation with pretrained weights: + python python/generate.py --model glm5 --model-weights-dir "xxxx" --with-mtp \ + 2>&1 | tee test.log """ args = parse_args() - generator: ShowHandsGenerator = ShowHandsGenerator( + generator = get_generator( + model_type=args.model, max_new_tokens=args.max_new_tokens, temperature=args.temperature, model_weights_dir=args.model_weights_dir, - with_mtp=args.with_mtp, + with_mtp=args.with_mtp if args.interactive else True, + top_p=args.top_p, + top_k=args.top_k, + enable_thinking=args.enable_thinking, + sampling_seed=args.sampling_seed, ) - if args.use_random_weights: - print("Initializing with random weights...") - generator.init_random_weights() - else: - print("Loading pretrained weights...") - generator.from_pretrained() + print("Loading pretrained weights...") + generator.from_pretrained() # simple memoryless interactive mode if args.interactive: @@ -72,72 +160,33 @@ def parse_args(): # type: ignore break _ = generator.generate(prompt) # type: ignore[has-type] else: - # This prompt is to test the model’s ability to follow instructions - # (in terms of quantity, type, and length) while keeping it fun. - print("==== Performance ====") - prompt = "Tell me 10 jokes, keep them all under 100 words." - print("Prompt:", prompt) - all_times = [] - all_accepted = [] - for _iter in range(20): - if _iter % 5 == 0: - print(f"Executing iter {_iter}...") - results, time_list, accepted_counts = cast( - tuple[str, list[float], list[int]], - generator.generate(prompt, False), # type: ignore[has-type] - ) - all_times.append(time_list) - all_accepted.append(accepted_counts) - - if args.with_mtp: - for token_num in range(100, 300, 100): - times_to_token_num = [] - for time_list, accepted_list in zip(all_times, all_accepted): - if len(time_list) > 5 and len(accepted_list) > 5: - times = time_list[5:] - accepted = accepted_list[5:] - cumsum_tokens = np.cumsum(accepted) - cumsum_times = np.cumsum(times) - # Find index where we reach token_num tokens - idx = np.searchsorted(cumsum_tokens, token_num) - if idx < len(cumsum_times): - times_to_token_num.append(cumsum_times[idx]) - if times_to_token_num: - mean_total_time = np.mean(times_to_token_num) - mean_time = mean_total_time / token_num - speed = 1 / mean_time - out_str = ( - f"**Perf@{token_num}: {speed:.3f} tokens/s & " - f"{(mean_time * 1000):.3f} ms**" - ) - print(out_str) - - # Print accepted tokens statistics - flat_accepted = [a for accepted_list in all_accepted for a in accepted_list] - if flat_accepted: - avg_accepted = sum(flat_accepted) / len(flat_accepted) - min_accepted = min(flat_accepted) - max_accepted = max(flat_accepted) - print( - f"**Accepted length: mean={avg_accepted:.2f}, " - f"min={min_accepted}, max={max_accepted}**" - ) - else: - all_times_np = np.array(all_times) - for token_num in range(100, 300, 100): - mean_time = np.mean(all_times_np[..., 5:token_num]) - speed = 1 / mean_time - out_str = ( - f"**Perf@{token_num}: {speed:.3f} tokens/s & {(mean_time * 1000):.3f} ms**" - ) - print(out_str) - print(results) - - # This prompt is used to test long sequence generation - prompt = "Hi, can you tell me a very long story, with roughly 3000 words?" - print("Prompt:", prompt) - print("Completion:") - completion, _, _ = generator.generate(prompt) # type: ignore[has-type] + + # 3 modes: top-k1 w/o MTP, top-k1 w/ MTP, top-p0.95 w/ MTP + modes = [ + BenchMode(with_mtp=False, label="top-k1 w/o MTP"), + BenchMode(with_mtp=True, label="top-k1 w/ MTP"), + BenchMode( + with_mtp=True, + label="top-p0.95 w/ MTP", + use_topp=True, + top_p=0.95, + top_k=args.top_k, + temperature=args.temperature, + ), + ] + + # Run all benchmarks and collect stats + all_bench_stats = [ + short_bench.run(generator, modes), + coding_bench.run(generator, modes), + long_bench.run(generator, modes), + ] + + # Print unified summary table + print_summary_table( + merge_stats(all_bench_stats), + model_name=args.model.upper(), + ) print("Cleaning up...") generator.cleanup() diff --git a/python/models/base.py b/python/models/base.py index b8a8219..58171a7 100644 --- a/python/models/base.py +++ b/python/models/base.py @@ -2,6 +2,7 @@ import os from abc import ABC, abstractmethod +from enum import Enum from typing import Any import torch @@ -9,8 +10,7 @@ from tilert import logger from tilert.models.deepseek_config import get_rank, get_world_size -from tilert.models.deepseek_v3_2.params import BaseParams -from tilert.models.preprocess import WeightLoader +from tilert.models.deepseek_v3_2.model_args import ModelArgs from tilert.utils import get_profile_log_tensor __all__ = [ @@ -18,6 +18,18 @@ ] +class TilertWeightsConverter: + """Tilert weights converter""" + + def __init__(self, model_args: ModelArgs, num_devices: int): + self.model_args = model_args + self.num_devices = num_devices + + def dispatch(self, algorithm: Enum, weights: list[torch.Tensor]) -> Any: + dispatch_method = getattr(self, f"convert_to_{algorithm.value}") + return dispatch_method(weights) + + class TileRTModule(nn.Module, ABC): """Base class for all TileRT modules. @@ -33,6 +45,9 @@ def __init__( tilert_weights_dir: str = "", layer_idx: int = 0, compute_kernel_type: str = "bf16", + model_args: ModelArgs | None = None, + num_devices: int = 8, + device_id: int = 0, *args: Any, **kwargs: Any, ) -> None: @@ -49,6 +64,17 @@ def __init__( """ super().__init__(*args, **kwargs) + self.model_args = model_args if model_args is not None else ModelArgs() + self.num_devices = num_devices + self.device_id = device_id + self.algorithm: Enum | None = None + + self.is_var_init = False + self.is_tilert_weights_init = False + self.is_ref_weights_init = False + + self.profile_logs: torch.Tensor | None = None + self.layer_idx = layer_idx self.flag_enable_tilert = False @@ -69,95 +95,32 @@ def __init__( self.golden_weights_dir = golden_weights_dir self.tilert_weights_dir = tilert_weights_dir - self.weight_loader = WeightLoader( - layer_idx=layer_idx, - golden_weights_dir=golden_weights_dir, - tilert_weights_dir=tilert_weights_dir, - ) - self.profile_logs = get_profile_log_tensor() - # members for debugging - self.temp_dir = os.path.join(os.path.expanduser("~"), ".cache", "tilert") - os.makedirs(self.temp_dir, exist_ok=True) - self.tmp_vars: dict[str, torch.Tensor] = {} + def get_cache_vars(self) -> list[torch.Tensor]: + return [] - def register_weights(self, weights_config: dict[str, dict[str, Any]]) -> None: - """Register weights configuration. + def get_tilert_weights_alias(self) -> list[str]: + return list(self.tilert_weights_alias()) - Args: - weights_config: Dictionary mapping weight names to their configurations. - """ - self.weight_loader.register_weights(weights_config) + def get_ref_weights_alias(self) -> list[str]: + return list(self.ref_weights_alias()) - def load_weights(self, device_id: int = 0) -> None: - """Load weights from the weights path.""" - golden_weights_path = self.weight_loader.get_weight_file_path( - device_id=device_id, is_tilert=False - ) - self.weight_loader.load_weights(weights_path=golden_weights_path, device_id=device_id) - - def load_tilert_weights(self, device_id: int = 0) -> None: - """Load tilert weights from the weights path.""" - tilert_weights_path = self.weight_loader.get_weight_file_path( - device_id=device_id, is_tilert=True - ) - self.weight_loader.load_tilert_weights( - weights_path=tilert_weights_path, device_id=device_id - ) - - def get_weight(self, name: str, from_tilert: bool = False) -> Any: - """Get a weight by name. - - Args: - name: Weight name. - from_tilert: Whether to get the weight from tilert. - """ - return self.weight_loader.get_weight(name, from_tilert) - - def wrap_var_name(self, var_name: str) -> str: - """Wrap the variable name. + def set_algorithm(self, algorithm: Enum) -> None: + """Set the algorithm for the module. Args: - var_name: Variable name. + algorithm: Algorithm. """ - return f"layer_{self.layer_idx}_{var_name}" - - def register_tmp_var(self, var_name: str, var_tensor: torch.Tensor) -> None: - """Register a temporary variable for debugging. - - Args: - var_name: Variable name. - var_tensor: Variable. - """ - self.tmp_vars[self.wrap_var_name(var_name)] = var_tensor - - def register_tmp_vars(self, var_dict: dict[str, torch.Tensor]) -> None: - """Register a list of temporary variables for debugging. + self.algorithm = algorithm - Args: - var_dict: Dictionary of variable names and variables. - """ - for var_name, tensor in var_dict.items(): - self.register_tmp_var(var_name, tensor) - - def dump_tmp_vars( - self, tmp_vars: dict[str, torch.Tensor] | None = None, save_dir: str = "" - ) -> None: - """Dump variables to the profile log file. + def register_weights(self, weights_config: dict[str, dict[str, Any]]) -> None: + """Register weights configuration. Args: - tensor_vars: Dictionary of variable names and tensors. - save_dir: Directory to save the variables. + weights_config: Dictionary mapping weight names to their configurations. """ - if tmp_vars is None: - tmp_vars = self.tmp_vars - save_dir = self.temp_dir if save_dir == "" else save_dir - os.makedirs(save_dir, exist_ok=True) - - for tensor_name in tmp_vars: - logger.info(f"Saving variable {tensor_name} to {save_dir}") - torch.save(tmp_vars[tensor_name], os.path.join(save_dir, f"{tensor_name}.pt")) + self.weight_loader.register_weights(weights_config) def get_profile_log_path(self) -> str: """Get the path to the profile log file. @@ -216,17 +179,6 @@ def tilert_forward(self, *args: Any, **kwargs: Any) -> Any: # noqa: U100 del args, kwargs raise NotImplementedError("Tilert forward not implemented") - @abstractmethod - def to_tilert_weights(self, *args: Any, **kwargs: Any) -> BaseParams | None: - """Convert weights to tilert. - - Args: - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - """ - del args, kwargs - raise NotImplementedError("Convert weights to tilert not implemented") - def enable_profiling_log(self, enable: bool = True) -> None: """Enable profiling log for this module and all submodules. @@ -256,3 +208,120 @@ def enable_tilert(self, enable: bool = True) -> None: # type: ignore module.flag_enable_tilert = enable if enable: module.to_tilert_weights() + + +class SerializableTileRTModule(TileRTModule): + """Serializable TileRT module.""" + + def __init__( + self, model_args: ModelArgs, device_id: int, num_devices: int, remove_selected: bool = False + ): + super().__init__( + type(self).__name__, model_args=model_args, device_id=device_id, num_devices=num_devices + ) + self.remove_selected = remove_selected + + self.exec_seq: list[TileRTModule] = [] + self.prefix_seq: list[str] = [] + self.suffix_seq: list[str] = [] + self.retain_weights_seq: list[bool] = [] + + def get_cache_vars(self) -> list[torch.Tensor]: + cache_vars = [] + for op in self.exec_seq: + cache_vars.extend(op.get_cache_vars()) + return cache_vars + + def register_op( + self, op: TileRTModule, prefix: str = "", suffix: str = "", retain_weights: bool = False + ) -> None: + self.exec_seq.append(op) + self.prefix_seq.append(prefix) + self.suffix_seq.append(suffix) + self.retain_weights_seq.append(retain_weights) + + def get_tilert_weights_alias(self) -> list[str]: + weights_alias: list[str] = [] + for op in self.exec_seq: + weights_alias.extend(op.get_tilert_weights_alias()) + return weights_alias + + def get_ref_weights_alias(self) -> list[str]: + weights_alias: list[str] = [] + for op in self.exec_seq: + weights_alias.extend(op.get_ref_weights_alias()) + return weights_alias + + def get_weights_list(self) -> list[torch.Tensor]: + weights = [] + for op in self.exec_seq: + weights.extend(op.get_weights_list()) + return weights + + def device_sharding(self, raw_weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + sharded_weights_map: dict[str, torch.Tensor] = {} + for op in self.exec_seq: + sharded_weights_map.update(op.device_sharding(raw_weights_map)) + return sharded_weights_map + + @property + def tilert_tensor_alias(self) -> list[str]: + """Return tilert tensor alias of the first sub-op (RMSNormProjxWqkvia).""" + tensor_alias: list[str] = [] + for op in self.exec_seq: + tensor_alias.extend(op.tilert_weights_alias()) + return tensor_alias + + @property + def ref_tensor_alias(self) -> list[str]: + """Return reference tensor alias of the first sub-op (RMSNormProjxWqkvia).""" + tensor_alias: list[str] = [] + for op in self.exec_seq: + tensor_alias.extend(op.ref_weights_alias()) + return tensor_alias + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + for op, prefix, suffix, retain_weights in zip( + self.exec_seq, self.prefix_seq, self.suffix_seq, self.retain_weights_seq + ): + keys_to_remove = set() + op_state_dict = {} + for op_key in op.get_tilert_weights_alias(): + original_key = f"{prefix}{op_key}{suffix}" + op_state_dict[op_key] = state_dict[original_key] + if self.remove_selected: + keys_to_remove.add(original_key) + op.init_tilert_weights(op_state_dict) + if self.remove_selected and not retain_weights: + for k in keys_to_remove: + del state_dict[k] + + def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + for op in self.exec_seq: + op.init_reference_weights(state_dict) + + def init_random_weights(self) -> None: + for op in self.exec_seq: + op.init_random_weights() + + def init_tilert_vars(self, batch_size: int, seq_len: int) -> None: + for op in self.exec_seq: + op.init_tilert_vars(batch_size, seq_len) + + def golden_forward( + self, + x: torch.Tensor, + pe_cache: torch.Tensor, + start_pos: int, + ) -> Any: + del x, pe_cache, start_pos + raise NotImplementedError("Golden forward is not implemented") + + def tilert_forward( + self, + x: torch.Tensor, + pe_cache: torch.Tensor, + start_pos: int, + ) -> Any: + del x, pe_cache, start_pos + raise NotImplementedError("Tilert forward is not implemented") diff --git a/python/models/common.py b/python/models/common.py new file mode 100644 index 0000000..b213793 --- /dev/null +++ b/python/models/common.py @@ -0,0 +1,361 @@ +from typing import cast + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +__all__ = [ + "init_func", + "linear", + "Linear", + "RMSNorm", + "LayerNorm", + "ColumnParallelLinear", + "RowParallelLinear", + "ParallelEmbedding", +] + +from tilert.models.deepseek_config import ( + block_size, + gemm_impl, + get_rank, + get_world_size, + is_distributed, +) +from tilert.models.deepseek_v3_2.refs.kernel import act_quant, fp8_gemm, weight_dequant + + +def _get_scale_tensor(tensor: torch.Tensor) -> torch.Tensor: + """Return the dynamically attached ``scale`` tensor.""" + scale = getattr(tensor, "scale", None) + if scale is None: + raise AttributeError("Expected quantized tensor to carry a 'scale' attribute.") + return cast(torch.Tensor, scale) + + +def init_func(x_in: torch.Tensor) -> torch.Tensor: + x_dtype = x_in.dtype + x_fp32 = x_in.to(torch.float32) + if x_fp32.dim() >= 2: + initial_tensor = nn.init.kaiming_uniform_(x_fp32) + else: + initial_tensor = nn.init.uniform_(x_fp32) + return initial_tensor.to(x_dtype) + + +def linear( + x_in: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, + scale_fmt: str | None = None, +) -> torch.Tensor: + """ + Applies a linear transformation to the incoming data: y = xA^T + b. + + This function supports specialized implementations based on quantization + and tensor formats. + + Args: + x_in (torch.Tensor): The input tensor. + weight (torch.Tensor): The weight tensor. It may be quantized and + requires dequantization for certain cases. + bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None. + + Returns: + torch.Tensor: The result of the linear transformation, which may involve + quantization-aware computations depending on the input parameters. + + Notes: + - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version is used + for computation. + - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied. + - For other cases, the function applies quantization to `x` and uses `fp8_gemm` + for computation. + """ + if weight.element_size() > 1: + return F.linear(x_in, weight, bias) + if gemm_impl == "bf16": + weight = weight_dequant(weight, _get_scale_tensor(weight)) + return F.linear(x_in, weight, bias) + + x_quant: torch.Tensor + scale: torch.Tensor + x_quant, scale = act_quant(x_in, block_size, scale_fmt) + y_out: torch.Tensor = fp8_gemm(x_quant, scale, weight, _get_scale_tensor(weight)) + if bias is not None: + y_out += bias + return y_out + + +class Linear(nn.Module): + """ + Custom linear layer with support for quantized weights and optional bias. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ + + dtype = torch.bfloat16 + scale_fmt: str | None = None + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + dtype: torch.dtype | None = None, + weight: torch.Tensor | None = None, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + if weight is not None: + self.weight = torch.nn.Parameter(weight) + else: + self.weight = nn.Parameter( + init_func(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype)) + ) + + if self.weight.element_size() == 1: + scale_out_features = (out_features + block_size - 1) // block_size + scale_in_features = (in_features + block_size - 1) // block_size + scale_param = nn.Parameter( + init_func( + torch.empty( + scale_out_features, + scale_in_features, + dtype=torch.float32, + ) + ) + ) + self.scale = scale_param + self.weight.scale = scale_param # type: ignore[attr-defined] + else: + self.register_parameter("scale", None) + + if bias: + self.bias = nn.Parameter(init_func(torch.empty(out_features))) + else: + self.register_parameter("bias", None) + + def forward(self, x_in: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the custom linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor after linear computation. + """ + return linear(x_in, self.weight, self.bias, self.scale_fmt) + + +class RMSNorm(nn.Module): + """ + Root Mean Square Layer Normalization (RMSNorm). + + Args: + dim (int): Dimension of the input tensor. + eps (float): Epsilon value for numerical stability. Defaults to 1e-6. + """ + + def __init__(self, dim: int, eps: float = 1e-6, weight: torch.Tensor | None = None): + super().__init__() + self.dim = dim + self.eps = eps + + if weight is None: + self.weight = nn.Parameter(init_func(torch.empty(dim, dtype=torch.float32))) + else: + self.weight = torch.nn.Parameter(weight) + + def forward( + self, x: torch.Tensor, residual: torch.Tensor | None = None + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for RMSNorm. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Normalized tensor with the same shape as input. + """ + dtype = torch.bfloat16 # x.dtype + if residual is None: + x = x.float() + var_s = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(var_s + self.eps) + return (self.weight * x).to(dtype) + + x = residual = x.float() + residual.float() + var_s = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(var_s + self.eps) + return (self.weight * x).to(dtype), residual.to(dtype) + + +class LayerNorm(nn.Module): + """Layer Normalization.""" + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x) + + +class ColumnParallelLinear(Linear): + """ + Column parallel linear layer. + + Linear layer with column parallelism, splitting output features across + distributed processes. + + Args: + in_features (int): Number of input features. + out_features (int): Total number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + dtype: torch.dtype | None = None, + ): + world_size = get_world_size() + assert ( + out_features % world_size == 0 + ), f"Output features must be divisible by world size {world_size}" + self.part_out_features = out_features // world_size + super().__init__(in_features, self.part_out_features, bias, dtype) + + def forward(self, x_in: torch.Tensor) -> torch.Tensor: + """ + Forward pass for column parallel linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor with column-parallel computation. + """ + return linear(x_in, self.weight, self.bias) + + +class RowParallelLinear(Linear): + """ + Linear layer with row parallelism, splitting input features across distributed processes. + + Args: + in_features (int): Total number of input features. + out_features (int): Number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + reduce_output: bool = True, + dtype: torch.dtype | None = None, + ): + + self.world_size = get_world_size() + + if in_features % self.world_size != 0: + raise ValueError( + f"Input features must be divisible by world size (world_size={self.world_size})" + ) + + self.part_in_features = in_features // self.world_size + self.reduce_output = reduce_output + + super().__init__(self.part_in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for row parallel linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor with row-parallel computation. + """ + y = linear(x, self.weight, None, self.scale_fmt) + if self.reduce_output and is_distributed() > 1: + y = y.float() + dist.all_reduce(y) + if self.bias is not None: + y += self.bias + return y.type_as(x) + + +class ParallelEmbedding(nn.Module): + """ + Parallel embedding layer. + + Embedding layer with parallelism support across distributed processes. + + Args: + vocab_size (int): Vocabulary size. + dim (int): Embedding dimension. + """ + + def __init__(self, vocab_size: int, dim: int): + super().__init__() + self.vocab_size = vocab_size + self.dim = dim + + self.world_size = get_world_size() + self.rank = get_rank() + + assert ( + vocab_size % self.world_size == 0 + ), f"Vocabulary size must be divisible by world size {self.world_size}" + + self.part_vocab_size = vocab_size // self.world_size + self.vocab_start_idx = self.rank * self.part_vocab_size + self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size + + self.weight = nn.Parameter(init_func(torch.empty(self.part_vocab_size, self.dim))) + + def forward(self, x_in: torch.Tensor) -> torch.Tensor: + """ + Forward pass for parallel embedding layer. + + Args: + x (torch.Tensor): Input tensor containing token indices. + + Returns: + torch.Tensor: Embedded representations. + + Raises: + ValueError: If `world_size` is not defined. + """ + if self.world_size > 1: + mask = (x_in < self.vocab_start_idx) | (x_in >= self.vocab_end_idx) + x_in = x_in - self.vocab_start_idx + x_in[mask] = 0 + + y_out = F.embedding(x_in, self.weight) + + if is_distributed(): + y_out[mask] = 0 + dist.all_reduce(y_out) + return y_out diff --git a/python/models/deepseek_v3_2/dsa_mtp_e2e_show_hands.py b/python/models/deepseek_v3_2/dsa_mtp_e2e_show_hands.py deleted file mode 100644 index 3c3dcac..0000000 --- a/python/models/deepseek_v3_2/dsa_mtp_e2e_show_hands.py +++ /dev/null @@ -1,158 +0,0 @@ -"""DSA MTP E2E show hands for DeepSeek v3.2.""" - -from typing import Any - -import torch - -from tilert.models.deepseek_v3_2.dsa_show_hands import ShowHandsDSALayer - -__all__ = [ - "ShowHandsDsaMtpE2eLayer", - "dsa_mtp_e2e_show_hands_prepare_money", - "dsa_mtp_e2e_show_hands", - "dsa_mtp_e2e_show_hands_reset", - "dsa_mtp_e2e_show_hands_go_home", -] - - -def dsa_mtp_e2e_show_hands_prepare_money( - params: list[torch.Tensor], - temp_vars: list[torch.Tensor], - cache_vars: list[torch.Tensor], - profile_logs: torch.Tensor, -) -> Any: - """Prepare money for MTP E2E show hands.""" - return torch.ops.tilert.dsa_mtp_e2e_show_hands_prepare_money( - params, temp_vars, cache_vars, profile_logs - ) - - -def dsa_mtp_e2e_show_hands(draft_tokens: torch.Tensor) -> Any: - """Show hands with MTP E2E.""" - return torch.ops.tilert.dsa_mtp_e2e_show_hands(draft_tokens) - - -def dsa_mtp_e2e_show_hands_reset(placeholder: torch.Tensor) -> Any: - """Reset MTP E2E show hands.""" - return torch.ops.tilert.dsa_mtp_e2e_show_hands_reset(placeholder) - - -def dsa_mtp_e2e_show_hands_go_home(placeholder: torch.Tensor) -> Any: - """Cleanup MTP E2E show hands.""" - return torch.ops.tilert.dsa_mtp_e2e_show_hands_go_home(placeholder) - - -def dsa_mtp_e2e_show_hands_set_prefill_valid_tokens( - placeholder: torch.Tensor, num_valid_tokens: int -) -> Any: - """Set the number of valid (non-padding) tokens for prefill mode. - - This controls how many tokens are copied from draft_tokens to predicted_tokens - during prefill. Should be called before forward() when the chunk has padding. - - Args: - placeholder: Placeholder tensor for PyTorch dispatch (not used). - num_valid_tokens: Number of valid tokens in the chunk (1-4). - """ - return torch.ops.tilert.dsa_mtp_e2e_show_hands_set_prefill_valid_tokens( - placeholder, num_valid_tokens - ) - - -class ShowHandsDsaMtpE2eLayer(ShowHandsDSALayer): - """Show hands DSA MTP E2E layer for DeepSeek v3.2. - - Inherits from ShowHandsDSALayer and adds MTP-specific functionality. - """ - - # MTP constants - NUM_MTP = 3 - MTP_SEQ_LEN = NUM_MTP + 1 # 4 - - def __init__( - self, - max_seq_len: int, - with_weight_conversion: bool = True, - ) -> None: - super().__init__( - max_seq_len=max_seq_len, - model_path="", - with_weight_conversion=with_weight_conversion, - with_mtp=True, - ) - - def _get_num_cache_layers(self) -> int: - """Return number of cache layers (+1 for shared MTP cache).""" - return int(self.NUM_LAYERS) + 1 - - def _prepare_money( - self, - params: list[torch.Tensor], - intermediates: list[torch.Tensor], - caches: list[torch.Tensor], - profile_logs: torch.Tensor, - ) -> None: - """Prepare money for MTP E2E show hands.""" - dsa_mtp_e2e_show_hands_prepare_money(params, intermediates, caches, profile_logs) - - def _show_hands(self, draft_tokens: torch.Tensor) -> Any: - """Run MTP E2E show hands forward.""" - return dsa_mtp_e2e_show_hands(draft_tokens.cpu()) - - def _reset_sequence_impl(self) -> None: - """Reset MTP E2E sequence.""" - dsa_mtp_e2e_show_hands_reset(self.placeholder) - - def _cleanup_impl(self) -> None: - """Cleanup MTP E2E resources.""" - dsa_mtp_e2e_show_hands_go_home(self.placeholder) - - def set_prefill_valid_tokens(self, num_valid_tokens: int) -> None: - """Set the number of valid tokens for prefill mode. - - This controls how many tokens are copied from draft_tokens to predicted_tokens - during prefill. Should be called before forward() when the chunk has padding. - - Args: - num_valid_tokens: Number of valid tokens in the chunk (1-4). - """ - dsa_mtp_e2e_show_hands_set_prefill_valid_tokens(self.placeholder, num_valid_tokens) - - def get_next_draft_tokens(self, device_id: int = 0) -> torch.Tensor: - """Get next_draft_tokens from the specified device. - - Args: - device_id: Device ID to get results from. - - Returns: - next_draft_tokens tensor of shape [1, MTP_SEQ_LEN]. - """ - intermediates, _, _, _ = self._get_device_result(device_id) - # next_draft_tokens is at index 38 (DsaTempVars::kNextDraftTokensIdx) - return intermediates[38] - - def get_num_accepted(self, device_id: int = 0) -> int: - """Get number of accepted tokens from the specified device. - - Args: - device_id: Device ID to get results from. - - Returns: - Number of accepted tokens. - """ - intermediates, _, _, _ = self._get_device_result(device_id) - # accepted_tokens (num_accepted) is at index 37 (DsaTempVars::kAcceptedTokensIdx) - return int(intermediates[37][0].item()) - - def get_predicted_tokens(self, device_id: int = 0) -> torch.Tensor: - """Get predicted_tokens from the specified device. - - Args: - device_id: Device ID to get results from. - - Returns: - predicted_tokens tensor containing main model predictions. - """ - intermediates, _, _, _ = self._get_device_result(device_id) - # predicted_tokens is at index 35 (DsaTempVars::kPredictedTokensIdx) - return intermediates[35] diff --git a/python/models/deepseek_v3_2/dsa_show_hands.py b/python/models/deepseek_v3_2/dsa_show_hands.py deleted file mode 100644 index ca781a2..0000000 --- a/python/models/deepseek_v3_2/dsa_show_hands.py +++ /dev/null @@ -1,1040 +0,0 @@ -"""DSA show hands for deepseek v3.2.""" - -import glob -import json -import math -import os -import sys -import threading -import time -from typing import Any - -import torch -from safetensors.torch import load_file -from transformers import AutoTokenizer - -from tilert import logger -from tilert.models.base import TileRTModule -from tilert.models.deepseek_v3_2.model_args import ModelArgs as ModelArgsV3_2 -from tilert.models.deepseek_v3_2.params import ( - BaseParams, - CacheVars, - DenseLayerParamsKeys, - Dsa671BModelInitializer, - IntermediateMapper, - MoELayerParamsKeys, - TempVars, -) -from tilert.models.preprocess.weight_utils import ( - DownAllreduceWeightsConverter, - ExpertSelectUpGateSiLUWeightsConverter, - RMSNormHeadProjWeightsConverter, - RMSNormProjQAKVAKIWeightsConverter, - RMSNormUpGateSiLUWeightsConverter, - UnProjOAllreduceWeightsConverter, -) -from tilert.models.utils import precompute_freqs_cis -from tilert.tilert_init import tilert_init -from tilert.utils import get_profile_log_tensor - -__all__ = [ - "ShowHandsGenerator", -] - -# MTP layer ID constant -MTP_LAYER_ID = 61 - -# MTP params keys order (for layer 61) -MTPPreprocessParamsKeys = [ - "embedding_rmsnorm_gamma", - "hidden_rmsnorm_gamma", - "eh_proj_weights", -] - -MTPMlaParamsKeys = [ - "x_rmsnorm_gamma", - "qkv_wa_weights", - "qkv_wa_scales", - "k_weights", - "k_bias", - "q_rmsnorm_gamma", - "q_wb_weights", - "q_wb_scales", - "id_score_weights", - "wkv_b1_weights", - "wkv_b1_scales", - "kv_rmsnorm_gamma", - "wkv_b2_weights", - "wkv_b2_scales", - "unproj_weights", - "unproj_scales", -] - -MTPMoeParamsKeys = [ - "unproj_o_gamma", - "exp_proj_weights", - "exp_bias", - "exp_upgate_weights", - "exp_upgate_scales", - "exp_down_weights", - "exp_down_scales", -] - - -def stats_time(time_list: list[float], title: str) -> None: - if len(time_list) > 0: - avg_time = sum(time_list) / len(time_list) - std_dev = math.sqrt(sum((x - avg_time) ** 2 for x in time_list) / len(time_list)) - logger.info(title) - logger.info(f"--Average time taken to generate token: {avg_time * 1000:.4f} ms") - logger.info(f"--Standard deviation of time: {std_dev * 1000:.4f} ms") - logger.info(f"--Effective tokens per second: {1 / avg_time:.4f}") - - -DeviceResult = tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], torch.Tensor] - - -def dsa_show_hands_prepare_money( - params: list[torch.Tensor], - temp_vars: list[torch.Tensor], - cache_vars: list[torch.Tensor], - profile_logs: torch.Tensor, - forward_max_seq_len: int, -) -> Any: - """Prepare money for show hands""" - return torch.ops.tilert.dsa_show_hands_prepare_money( - params, temp_vars, cache_vars, profile_logs, forward_max_seq_len - ) - - -def dsa_show_hands(token_id: torch.Tensor) -> Any: - """Show hands with native MT""" - return torch.ops.tilert.dsa_show_hands(token_id) - - -def dsa_show_hands_reset(placeholder: torch.Tensor) -> Any: - """Reset show one hand""" - return torch.ops.tilert.dsa_show_hands_reset(placeholder) - - -def dsa_show_hands_go_home(placeholder: torch.Tensor) -> Any: - """Go home""" - return torch.ops.tilert.dsa_show_hands_go_home(placeholder) - - -# Put dirty conversion code here. -# TODO: better way to handle conversion. -def _convert_weights_on_demand( - state_dicts: dict[str, torch.Tensor], - skip_mtp_layer: bool = False, -) -> dict[str, torch.Tensor]: - """Convert weights on demand. - - Args: - state_dicts: Dictionary of weights to convert. - skip_mtp_layer: If True, skip layer 61 (MTP layer) weights. - """ - res_dicts = {} - for key, value in state_dicts.items(): - # Skip layer 61 (MTP layer) MLA/MoE/preprocess weights if requested, - # but NOT lm_head and model.norm.weight (used by main model head) - if ( - skip_mtp_layer - and f"layer_{MTP_LAYER_ID}_" in key - and "lm_head.weight" not in key - and "model.norm.weight" not in key - ): - res_dicts[key] = value - continue - - if "qkv_wa_weights" in key: # first op - weight_key = key - scale_key = key.replace("qkv_wa_weights", "qkv_wa_scales") - gamma_key = key.replace("qkv_wa_weights", "x_rmsnorm_gamma") - common_weights = RMSNormProjQAKVAKIWeightsConverter.tilert_to_common( - state_dicts[weight_key], - state_dicts[scale_key], - state_dicts[gamma_key], - ) - conv_weights = ( - RMSNormProjQAKVAKIWeightsConverter.common_to_tilert_native_bf16_warp_gemv( - *common_weights - ) - ) - res_dicts[key] = conv_weights[0] - elif "unproj_weights" in key: # unprojo_allreduce op - weight_key = key - scale_key = key.replace("unproj_weights", "unproj_scales") - weights, scales = UnProjOAllreduceWeightsConverter.tilert_to_tilert_112sm_mma( - state_dicts[weight_key], - state_dicts[scale_key], - ) - res_dicts[weight_key] = weights - res_dicts[scale_key] = scales - state_dicts[weight_key] = None - elif "unproj_scales" in key: # skip unprojo_allreduce op:: scales - pass - elif "exp_upgate_weights" in key: # expert select up gate silu op - weight_key = key - scale_key = key.replace("exp_upgate_weights", "exp_upgate_scales") - weights_and_scales = ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm_mma( - state_dicts[weight_key], state_dicts[scale_key] - ) - res_dicts[key] = weights_and_scales - state_dicts[weight_key] = None - - elif "upgate_weights" in key: # rmsnorm up gate silu op - weight_key = key - scale_key = key.replace("upgate_weights", "upgate_scales") - weights_and_scales = RMSNormUpGateSiLUWeightsConverter.tilert_to_tilert_144sm_mma( - state_dicts[weight_key], - state_dicts[scale_key], - ) - res_dicts[key] = weights_and_scales - state_dicts[weight_key] = None - elif "down_weights" in key: # expert down allreduce op - weight_key = key - scale_key = key.replace("down_weights", "down_scales") - weights_swizzled, scales = DownAllreduceWeightsConverter.tilert_to_tilert_mma( - state_dicts[weight_key], - state_dicts[scale_key], - ) - res_dicts[weight_key] = weights_swizzled - res_dicts[scale_key] = scales - state_dicts[weight_key] = None - elif "lm_head.weight" in key: # head projection weights - weights = RMSNormHeadProjWeightsConverter.tilert_to_tilert_native_bf16_warp_gemv( - state_dicts[key] - ) - res_dicts[key] = weights - state_dicts[key] = None - else: - res_dicts[key] = value - - return res_dicts - - -def _convert_mtp_weights_on_demand( - state_dicts: dict[str, torch.Tensor], -) -> dict[str, torch.Tensor]: - """Convert MTP layer weights on demand to TileRT optimized format. - - This applies conversions specifically for MTP layer weights (layer 61). - Only processes keys that contain 'layer_61_'. - Note: lm_head and model.norm.weight are reused from main model (already converted). - """ - res_dicts = {} - for key, value in state_dicts.items(): - # Only process layer 61 (MTP layer) weights - if f"layer_{MTP_LAYER_ID}_" not in key: - res_dicts[key] = value - continue - - # Skip lm_head and model.norm.weight - they're reused from main model - # and already converted by _convert_weights_on_demand - if "lm_head.weight" in key or "model.norm.weight" in key: - res_dicts[key] = value - continue - - if "qkv_wa_weights" in key: # first op - weight_key = key - scale_key = key.replace("qkv_wa_weights", "qkv_wa_scales") - gamma_key = key.replace("qkv_wa_weights", "x_rmsnorm_gamma") - common_weights = RMSNormProjQAKVAKIWeightsConverter.tilert_to_common( - state_dicts[weight_key], - state_dicts[scale_key], - state_dicts[gamma_key], - ) - conv_weights = ( - RMSNormProjQAKVAKIWeightsConverter.common_to_tilert_native_bf16_warp_gemv( - *common_weights - ) - ) - res_dicts[key] = conv_weights[0] - elif "unproj_weights" in key: # unproj_o_allreduce op - weight_key = key - scale_key = key.replace("unproj_weights", "unproj_scales") - weights, scales = UnProjOAllreduceWeightsConverter.tilert_to_tilert_112sm_mma( - state_dicts[weight_key], - state_dicts[scale_key], - ) - res_dicts[weight_key] = weights - res_dicts[scale_key] = scales - elif "unproj_scales" in key: # skip - already processed with unproj_weights - pass - elif "exp_upgate_weights" in key: # expert select up gate silu op - weight_key = key - scale_key = key.replace("exp_upgate_weights", "exp_upgate_scales") - weights_and_scales = ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm_mma( - state_dicts[weight_key], state_dicts[scale_key] - ) - res_dicts[key] = weights_and_scales - elif "exp_down_weights" in key: # expert down allreduce op - weight_key = key - scale_key = key.replace("exp_down_weights", "exp_down_scales") - weights_swizzled, scales = DownAllreduceWeightsConverter.tilert_to_tilert_mma( - state_dicts[weight_key], - state_dicts[scale_key], - ) - res_dicts[weight_key] = weights_swizzled - res_dicts[scale_key] = scales - else: - res_dicts[key] = value - - return res_dicts - - -class ShowHandsDSALayer(TileRTModule): - """Show hands DSA for deepseek v3.2.""" - - NUM_DENSE_LAYERS = 3 - NUM_MOE_LAYERS = 58 - NUM_LAYERS = NUM_DENSE_LAYERS + NUM_MOE_LAYERS - - def __init__( - self, - max_seq_len: int, - model_path: str = "", - with_weight_conversion: bool = True, - with_mtp: bool = False, - ) -> None: - super().__init__() - self.hidden_size = 7168 - self.forward_max_seq_len = 4 # max supported seq_len per forward - self.batch_size = 1 - self.num_heads = 16 - - self.q_dim = 1536 - self.kv_dim = 512 - self.k_pe_dim = 64 - - self.q_pe_lora_dim = 64 - self.q_pe_dim = 512 - self.q_nope_dim = 128 - - self.v_head_dim = 128 - - self.n_routed_experts = 256 - self.n_activate_experts = 8 - self.exp_dims = 256 - - self.max_seq_len = max_seq_len - - self.vocab_size_full = 129280 - self.vocab_size = self.vocab_size_full // 8 # 16160 - - self.num_devices = 8 - - self.model_path = model_path - self.with_weight_conversion = with_weight_conversion - self.with_mtp = with_mtp - - self.multi_devices_results: list[DeviceResult | None] = [None] * torch.cuda.device_count() - - self.kv_cache = torch.zeros( - self.batch_size, self.max_seq_len, self.kv_dim, dtype=torch.bfloat16, device="cuda:0" - ) - self.pe_cache = torch.zeros( - self.batch_size, self.max_seq_len, self.k_pe_dim, dtype=torch.bfloat16, device="cuda:0" - ) - self.k_cache = torch.zeros( - self.batch_size, self.max_seq_len, 128, dtype=torch.bfloat16, device="cuda:0" - ) - - self.placeholder = torch.zeros(1, 1, dtype=torch.int32, device="cpu") - - def _get_num_cache_layers(self) -> int: - """Return number of cache layers. Override in subclass for MTP.""" - return self.NUM_LAYERS - - def _prepare_money( - self, - params: list[torch.Tensor], - intermediates: list[torch.Tensor], - caches: list[torch.Tensor], - profile_logs: torch.Tensor, - ) -> None: - """Prepare money for show hands. Override in subclass for MTP.""" - dsa_show_hands_prepare_money( - params, intermediates, caches, profile_logs, self.forward_max_seq_len - ) - - def _show_hands(self, token_id: torch.Tensor) -> Any: - """Run show hands forward. Override in subclass for MTP.""" - return dsa_show_hands(token_id.cpu()) - - def _reset_sequence_impl(self) -> None: - """Reset sequence implementation. Override in subclass for MTP.""" - dsa_show_hands_reset(self.placeholder) - - def _cleanup_impl(self) -> None: - """Cleanup implementation. Override in subclass for MTP.""" - dsa_show_hands_go_home(self.placeholder) - - def golden_forward(self) -> None: - raise NotImplementedError("golden_forward not implemented") - - def tilert_forward(self) -> None: - raise NotImplementedError("tilert_forward not implemented") - - def to_tilert_weights(self) -> BaseParams: - raise NotImplementedError("to_tilert_weights not implemented") - - def get_mla_moe_layer_params_dict( - self, layer_id: int, device: torch.device, dev_attrs: dict - ) -> dict[str, torch.Tensor]: - del dev_attrs - dsa_671b_model = Dsa671BModelInitializer( - torch.device(device), - with_weight_conversion=self.with_weight_conversion, - ) - return { - **dsa_671b_model.init_mla_params().to_dict(layer_id, device), - **dsa_671b_model.init_moe_params().to_dict(layer_id, device), - } - - def get_mla_mlp_layer_params_dict( - self, layer_id: int, device: torch.device, dev_attrs: dict - ) -> dict[str, torch.Tensor]: - del dev_attrs - dsa_671b_model = Dsa671BModelInitializer( - torch.device(device), - with_weight_conversion=self.with_weight_conversion, - ) - return { - **dsa_671b_model.init_mla_params().to_dict(layer_id, device), - **dsa_671b_model.init_mlp_params().to_dict(layer_id, device), - } - - def get_llm_head_layer_params_dict( - self, layer_id: int, device: torch.device, dev_attrs: dict - ) -> dict[str, torch.Tensor]: - del dev_attrs - dsa_671b_model = Dsa671BModelInitializer( - torch.device(device), - with_weight_conversion=self.with_weight_conversion, - ) - return {**dsa_671b_model.init_llm_head_params().to_dict(layer_id, device)} - - def get_temp_vars(self, device: torch.device, dev_attrs: dict) -> TempVars: - del dev_attrs - dsa_671b_model = Dsa671BModelInitializer( - torch.device(device), - with_weight_conversion=self.with_weight_conversion, - with_mtp=self.with_mtp, - ) - return dsa_671b_model.acquire_temp_vars() - - def get_cache_vars(self, device: torch.device, dev_attrs: dict) -> CacheVars: - del dev_attrs - return CacheVars( - torch.zeros_like(self.k_cache).to(device), - torch.zeros_like(self.kv_cache).to(device), - torch.zeros_like(self.pe_cache).to(device), - ) - - def _gen_freqs_cis(self) -> torch.Tensor: - freqs_cis = precompute_freqs_cis(ModelArgsV3_2()) - return torch.view_as_real(freqs_cis).reshape(freqs_cis.shape[0], -1) - - def get_dev_id(self, weight_name: str) -> int: - line_splits = weight_name.split("_dev_") - if len(line_splits) == 2: - return int(line_splits[1]) - - return -1 - - def get_weight_files(self, weight_map: dict[str, str], device_id: int) -> list[str]: - """Get the weight files for the given device.""" - weight_files = [] # to preserve the order of weight files - weight_files_set_ = set() # to avoid duplicate weight files - for weight in weight_map: - dev_id = self.get_dev_id(weight) - if dev_id == -1 or dev_id != device_id: - continue - - weight_file = weight_map[weight] - if weight_file in weight_files_set_: - continue - weight_files_set_.add(weight_file) - weight_files.append(weight_file) - - return weight_files - - def load_embedding_weights(self, model_path: str, device_id: int) -> torch.Tensor: - """Load the embedding weights for the given device.""" - # Look up the embedding file from index.json instead of hardcoding - index_file = os.path.join(model_path, "model.safetensors.index.json") - with open(index_file, encoding="utf-8") as f: - weights_index = json.load(f) - weight_map = weights_index["weight_map"] - - embed_key = "model.embed_tokens.weight" - if embed_key not in weight_map: - raise ValueError(f"Embedding weight {embed_key} not found in index.json") - - embed_weights_file = weight_map[embed_key] - embed_weights_file_path = os.path.join(model_path, embed_weights_file) - state_dict = load_file(embed_weights_file_path, device=f"cuda:{device_id}") - return state_dict[embed_key] - - def get_total_shards(self, model_path: str) -> int: - """Get the total number of shards by counting safetensors files in the directory.""" - safetensors_files = glob.glob(os.path.join(model_path, "*.safetensors")) - return len(safetensors_files) - - def _init_weights(self, model_path: str | None) -> None: - """Load the model weights from the given path or generate random weights.""" - - def _load_state_dicts(model_path: str, dev_attrs: dict) -> dict[str, torch.Tensor]: - device_id = dev_attrs["device"] - index_file = "model.safetensors.index.json" - with open(os.path.join(model_path, index_file), encoding="utf-8") as f: - weights_index = json.load(f) - weight_map = weights_index["weight_map"] - - weight_files = self.get_weight_files(weight_map, device_id) - state_dicts = {} - for weight_file in weight_files: - state_dict = load_file( - os.path.join(model_path, weight_file), device=f"cuda:{device_id}" - ) - logger.info(f"Loaded weights from {weight_file} for device {device_id}") - state_dicts.update(state_dict) - embed_weights = self.load_embedding_weights(model_path, device_id) - state_dicts["model.embed_tokens.weight"] = embed_weights - return state_dicts - - def _gen_state_dicts_with_random_weights(dev_attrs: dict) -> dict[str, torch.Tensor]: - device_id = dev_attrs["device"] - state_dicts = {} - for layer_id in range(self.NUM_DENSE_LAYERS): - state_dicts.update( - self.get_mla_mlp_layer_params_dict(layer_id, device_id, dev_attrs) - ) - for layer_id in range(3, 3 + self.NUM_MOE_LAYERS): - state_dicts.update( - self.get_mla_moe_layer_params_dict(layer_id, device_id, dev_attrs) - ) - state_dicts.update(self.get_llm_head_layer_params_dict(61, device_id, dev_attrs)) - dsa_671b_model = Dsa671BModelInitializer( - torch.device(device_id), - with_weight_conversion=self.with_weight_conversion, - ) - state_dicts.update(dsa_671b_model.init_embedding_params().to_dict(device_id)) - return state_dicts - - def __load_weights(device_id: int, model_path: str | None) -> None: - intermediates: list[torch.Tensor] = [] - caches: list[torch.Tensor] = [] - params: list[torch.Tensor] = [] - state_dicts = {} - dev_attrs = { - "device": device_id, - "dtype": torch.bfloat16, - } - start_time = time.time() - with torch.cuda.device(device_id): - intermediates.extend( - self.get_temp_vars( - device_id, dev_attrs - ).generate_params_with_continuous_storage(device_id) - ) - num_cache_layers = self._get_num_cache_layers() - for _ in range(num_cache_layers): - caches.extend(self.get_cache_vars(device_id, dev_attrs).get_params()) - logger.info(f"Created intermediates and caches for device {device_id}") - - load_from_path = bool(model_path and os.path.exists(model_path)) - if load_from_path: - assert model_path is not None # Type narrowing for mypy - state_dicts = _load_state_dicts(model_path, dev_attrs) - else: - state_dicts = _gen_state_dicts_with_random_weights(dev_attrs) - # Do necessary weight conversions only for loaded weights. - # Skip MTP layer (layer 61) conversion here if with_mtp is True, - # as it will be handled separately. - if load_from_path: - state_dicts = _convert_weights_on_demand( - state_dicts, skip_mtp_layer=self.with_mtp - ) - - for layer_id in range(self.NUM_DENSE_LAYERS): - # Each layer has its dedicated cache - for param in DenseLayerParamsKeys: - key_name = f"layer_{layer_id}_{param}_dev_{device_id}" - if key_name not in state_dicts: - raise ValueError(f"Weight {key_name} not found") - params.append(state_dicts[key_name]) - - for layer_id in range(3, 3 + self.NUM_MOE_LAYERS): - # Each layer has its dedicated cache - for param in MoELayerParamsKeys: - key_name = f"layer_{layer_id}_{param}_dev_{device_id}" - if key_name not in state_dicts: - raise ValueError(f"Weight {key_name} not found") - params.append(state_dicts[key_name]) - - # heads - head = f"layer_61_lm_head.weight_dev_{device_id}" - head_norm = f"layer_61_model.norm.weight_dev_{device_id}" - if head not in state_dicts: - raise ValueError(f"Weight {head} not found") - if head_norm not in state_dicts: - raise ValueError(f"Weight {head_norm} not found") - params.append(state_dicts[head_norm]) - params.append(state_dicts[head]) - - # embed_weights = self.load_embedding_weights(total_shards, device_id) - params.append(state_dicts["model.embed_tokens.weight"]) - - # RoPE frequencies - freqs_cis = self._gen_freqs_cis() - params.extend([freqs_cis.to(device_id)]) - - # Add MTP-specific params when with_mtp is True - if self.with_mtp: - if load_from_path: - # Load real MTP weights from state_dicts - # Convert MTP-specific weights (layer 61) - state_dicts = _convert_mtp_weights_on_demand(state_dicts) - - # MTP params order (matching C++ register_mtp.cu): - # 1. LlmPreprocessModule: 2 (embedding + freqs_cis) - # 2. MtpPreProcessLayer: 3 (embedding_rmsnorm_gamma, - # hidden_rmsnorm_gamma, eh_proj_weights) - # 3. MoeBlock: 23 (MLA 16 + MOE 7) - # 4. LlmHeadModule: 2 (hidden_rms_gamma, head_proj_weights) - - # 1. Embedding params (2) - params.append(state_dicts["model.embed_tokens.weight"]) - mtp_freqs_cis = self._gen_freqs_cis() - params.append(mtp_freqs_cis.to(device_id)) - - # 2. MTP preprocess params (3) - for key in MTPPreprocessParamsKeys: - full_key = f"layer_{MTP_LAYER_ID}_{key}_dev_{device_id}" - if full_key not in state_dicts: - raise ValueError(f"MTP weight {full_key} not found") - params.append(state_dicts[full_key]) - - # 3. MLA params (16) + MOE params (7) = MoeBlock (23) - for key in MTPMlaParamsKeys: - full_key = f"layer_{MTP_LAYER_ID}_{key}_dev_{device_id}" - if full_key not in state_dicts: - raise ValueError(f"MTP weight {full_key} not found") - params.append(state_dicts[full_key]) - - for key in MTPMoeParamsKeys: - full_key = f"layer_{MTP_LAYER_ID}_{key}_dev_{device_id}" - if full_key not in state_dicts: - raise ValueError(f"MTP weight {full_key} not found") - params.append(state_dicts[full_key]) - - # 4. LLM head params (2) - mtp_head_norm = f"layer_{MTP_LAYER_ID}_model.norm.weight_dev_{device_id}" - mtp_head = f"layer_{MTP_LAYER_ID}_lm_head.weight_dev_{device_id}" - if mtp_head_norm not in state_dicts: - raise ValueError(f"MTP weight {mtp_head_norm} not found") - if mtp_head not in state_dicts: - raise ValueError(f"MTP weight {mtp_head} not found") - params.append(state_dicts[mtp_head_norm]) - params.append(state_dicts[mtp_head]) - - logger.info(f"Loaded real MTP weights for device {device_id}") - else: - # Use random weights for MTP - dsa_671b_model = Dsa671BModelInitializer( - torch.device(device_id), - with_weight_conversion=self.with_weight_conversion, - with_mtp=True, - ) - # MTP needs: embedding, mtp_preprocess, mla, moe, llm_head params - params.extend(dsa_671b_model.init_embedding_params().get_params()) - params.extend(dsa_671b_model.init_mtp_preprocess_params().get_params()) - params.extend(dsa_671b_model.init_mla_params().get_params()) - params.extend(dsa_671b_model.init_moe_params().get_params()) - params.extend(dsa_671b_model.init_llm_head_params().get_params()) - - profile_logs = get_profile_log_tensor(device=device_id, num_max_insts=65536) - result = (intermediates, caches, params, profile_logs) - self.multi_devices_results[device_id] = result - - elapsed_time = time.time() - start_time - minutes = int(elapsed_time // 60) - seconds = int(elapsed_time % 60) - time_str = ( - f"{minutes} minutes {seconds} seconds" if minutes > 0 else f"{seconds} seconds" - ) - logger.info(f"Completed loading weights for device {device_id} in {time_str}") - - threads = [] - exceptions: list[Exception | None] = [None] * self.num_devices - for device_id in range(self.num_devices): - - def _runner(dev_id: int) -> None: - try: - __load_weights(dev_id, model_path) - except Exception as exc: # pragma: no cover - surfaced after join - exceptions[dev_id] = exc - - thread = threading.Thread(target=_runner, args=(device_id,)) - threads.append(thread) - thread.start() - for thread in threads: - thread.join() - for device_id, exc in enumerate(exceptions): - if exc is not None: - raise RuntimeError(f"Failed to initialize device {device_id}: {exc}") from exc - - # Prepare money for all devices - for device_id in range(self.num_devices): - with torch.cuda.device(device_id): - intermediates, caches, params, profile_logs = self._get_device_result(device_id) - self._prepare_money(params, intermediates, caches, profile_logs) - - def from_pretrained(self, model_path: str) -> None: - """Load the model weights from the given path.""" - if not os.path.exists(model_path): - raise ValueError(f"Model weights directory {model_path} does not exist") - self._init_weights(model_path) - - def init_random_weights(self) -> None: - """Generate random weights.""" - self._init_weights(None) - - def forward( - self, - token_id: torch.Tensor, - ) -> list[DeviceResult]: - self._show_hands(token_id) - return [self._get_device_result(device_id) for device_id in range(self.num_devices)] - - def reset_sequence(self) -> None: - self._reset_sequence_impl() - - def cleanup(self) -> None: - self._cleanup_impl() - - def __del__(self) -> None: - try: - self.cleanup() - except Exception as e: - print(f"Exception during cleanup: {e}", file=sys.stderr) - - def _get_device_result(self, device_id: int) -> DeviceResult: - device_result = self.multi_devices_results[device_id] - if device_result is None: - raise RuntimeError(f"Device {device_id} is not initialized") - return device_result - - -class ShowHandsGenerator: - def __init__( - self, - max_new_tokens: int = 100, - temperature: float = 1.0, - model_weights_dir: str = "", - with_mtp: bool = False, - ): - """Initialize the ShowHandsGenerator. - - Args: - max_new_tokens: Maximum number of new tokens to generate. Defaults to 100. - temperature: Temperature for sampling. Defaults to 1.0. - model_weights_dir: Path of the model weights directory. - with_mtp: Whether to use MTP (Multi-Token Prediction) for speculative decoding. - """ - torch.set_num_threads(64) - self.model_weights_dir = model_weights_dir - - self.max_new_tokens = max_new_tokens - self.temperature = temperature - self.with_mtp = with_mtp - - self.config = ModelArgsV3_2() - self.tokenizer = AutoTokenizer.from_pretrained(self.model_weights_dir) - self.eos_id = self.tokenizer.eos_token_id - self.batch_size = 1 # fixed batch size to 1 for now - - self.default_device = torch.device("cuda:0") - - if with_mtp: - from tilert.models.deepseek_v3_2.dsa_mtp_e2e_show_hands import ShowHandsDsaMtpE2eLayer - - self.decode_layer = ShowHandsDsaMtpE2eLayer( - max_seq_len=self.config.max_seq_len, - ) - self.mtp_seq_len = self.decode_layer.MTP_SEQ_LEN # 4 - else: - self.decode_layer = ShowHandsDSALayer( - max_seq_len=self.config.max_seq_len, - model_path=self.model_weights_dir, - ) - - def init(self) -> None: - """Initialize the ShowHandsGenerator.""" - tilert_init() - - def cleanup(self) -> None: - """Cleanup the ShowHandsGenerator.""" - self.decode_layer.cleanup() - - def init_random_weights(self) -> None: - """Random initialize the weights.""" - self.decode_layer.init_random_weights() - - def from_pretrained(self) -> None: - """Load the model weights from the given path.""" - self.decode_layer.from_pretrained(self.model_weights_dir) - - @torch.inference_mode() - def generate(self, prompt: str, print_log: bool = True) -> tuple[str, list[float], list[int]]: - """Main function to load the model and perform single sequence generation. - - Returns: - Tuple of (result_text, time_list, accepted_counts). - accepted_counts is empty for non-MTP mode. - """ - if self.with_mtp: - return self._generate_with_mtp(prompt, print_log) - result, time_list = self._generate_without_mtp(prompt, print_log) - return result, time_list, [] # Empty accepted_counts for non-MTP - - def _generate_without_mtp(self, prompt: str, print_log: bool = True) -> tuple[str, list[float]]: - """Standard generation without MTP.""" - prompt_tokens = self.tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], add_generation_prompt=True - ) - - max_seq_len = self.config.max_seq_len - prompt_len = len(prompt_tokens) - total_len = min(max_seq_len, self.max_new_tokens + prompt_len) - - tokens = torch.full( - (self.batch_size, total_len), -1, dtype=torch.long, device=self.default_device - ) - tokens[0, :prompt_len] = torch.tensor( - prompt_tokens, dtype=torch.long, device=self.default_device - ) - prompt_mask = tokens != -1 - - prev_pos = 0 - finished = torch.tensor( - [False] * self.batch_size, dtype=torch.bool, device=self.default_device - ) - - time_list = [] - for cur_pos_val in range(1, total_len): - start_time = time.time() - multi_devices_results = self.decode_layer.forward(tokens[0, prev_pos]) - end_time = time.time() - time_list.append(end_time - start_time) - - intermediates, *_ = multi_devices_results[0] - intermediates_mapper = IntermediateMapper(list(intermediates[-TempVars.num_params() :])) - next_token = intermediates_mapper.token_out[0][0] # only the first token - - # replace the next token with the prompt token if the prompt mask is True - next_token = torch.where( - prompt_mask[0, cur_pos_val], tokens[0, cur_pos_val], next_token - ) - tokens[0, cur_pos_val] = next_token - finished |= torch.logical_and(~prompt_mask[0, cur_pos_val], next_token == self.eos_id) - prev_pos = cur_pos_val - if cur_pos_val >= prompt_len: - decoded_tokens = self.tokenizer.decode( - [next_token.item()], skip_special_tokens=True - ) - if print_log: - print(decoded_tokens, end="", flush=True) - - if finished.all(): - break - - if print_log: - print("\n") - logger.info(f"--Number of tokens generated: {len(time_list)}") - - # skip the first several samples to avoid the warmup effect - stats_time(time_list[5:], "==== Performance ====") - print("\n") - - # Reset sequence after generation, i.e. reset the cur_pos to 0 internally - self.decode_layer.reset_sequence() - - completion_tokens = [] - for _, toks in enumerate(tokens.tolist()): - toks = toks[prompt_len : prompt_len + self.max_new_tokens] - if self.eos_id in toks: - toks = toks[: toks.index(self.eos_id)] - completion_tokens.append(toks) - - decoded_tokens = self.tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) - - return f"{decoded_tokens[0]}\n" if decoded_tokens else "", time_list - - def _generate_with_mtp( - self, prompt: str, print_log: bool = True - ) -> tuple[str, list[float], list[int]]: - """Generation with MTP (Multi-Token Prediction) speculative decoding.""" - prompt_tokens = self.tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], add_generation_prompt=True - ) - - max_seq_len = self.config.max_seq_len - prompt_len = len(prompt_tokens) - total_len = min(max_seq_len, self.max_new_tokens + prompt_len) - - # Output tokens buffer - tokens = torch.full( - (self.batch_size, total_len), -1, dtype=torch.long, device=self.default_device - ) - tokens[0, :prompt_len] = torch.tensor( - prompt_tokens, dtype=torch.long, device=self.default_device - ) - - prefill_time_list = [] - decode_time_list = [] - decode_accepted_counts = [] # Only track decode phase for statistics - cur_pos = 0 # Current position in the output sequence - - # Prefill phase: process prompt tokens in chunks - # Process prompt tokens in chunks of mtp_seq_len (with overlap) - while cur_pos < prompt_len - 1: - draft_end = min(cur_pos + self.mtp_seq_len, prompt_len) - draft_tokens = tokens[0, cur_pos:draft_end].clone() - actual_token_count = draft_tokens.shape[0] - - # Pad if needed (use last token for padding) - if actual_token_count < self.mtp_seq_len: - pad_token = draft_tokens[-1].item() - padding = torch.full( - (self.mtp_seq_len - actual_token_count,), - pad_token, - dtype=torch.long, - device=self.default_device, - ) - draft_tokens = torch.cat([draft_tokens, padding]) - - draft_tokens = draft_tokens.reshape(1, self.mtp_seq_len).to(torch.int32) - - # Tell GPU how many tokens are valid (for cur_pos advancement) - self.decode_layer.set_prefill_valid_tokens(actual_token_count) - - start_time = time.time() - self.decode_layer.forward(draft_tokens) - end_time = time.time() - prefill_time_list.append(end_time - start_time) - - # Advance cur_pos by (actual_token_count - 1) to maintain overlap - # This ensures cur_pos ends at prompt_len - 1 after all chunks - cur_pos += actual_token_count - 1 - - # Decode phase: speculative decoding - # Set prefill_valid_tokens to 0 to switch to decode mode - self.decode_layer.set_prefill_valid_tokens(0) - - finished = False - while cur_pos < total_len - 1 and not finished: - # Get next_draft_tokens from previous iteration - # (or use last prompt tokens for first decode) - if cur_pos == prompt_len - 1: - # First decode iteration: use last prompt token repeated as placeholder drafts - # We can't use [t6, t7, t8, t9] because that would apply wrong RoPE positions - # (cur_pos=9 means positions 9,10,11,12, but t6 should be at position 6) - last_token = tokens[0, prompt_len - 1].item() - draft_tokens = torch.full( - (self.mtp_seq_len,), - last_token, - dtype=torch.long, - device=self.default_device, - ) - draft_tokens = draft_tokens.reshape(1, self.mtp_seq_len).to(torch.int32) - else: - # Use next_draft_tokens from previous iteration - draft_tokens = self.decode_layer.get_next_draft_tokens(0).reshape( - 1, self.mtp_seq_len - ) - - start_time = time.time() - self.decode_layer.forward(draft_tokens) - end_time = time.time() - decode_time_list.append(end_time - start_time) - - num_accepted = self.decode_layer.get_num_accepted(0) - # Use predicted_tokens for output (not next_draft_tokens which is for next iteration) - predicted_tokens = self.decode_layer.get_predicted_tokens(0).flatten() - decode_accepted_counts.append(num_accepted) - - # Add accepted tokens to output - num_output_tokens = num_accepted - for i in range(num_output_tokens): - if cur_pos + 1 + i >= total_len: - break - new_token = int(predicted_tokens[i].item()) - tokens[0, cur_pos + 1 + i] = new_token - - # Print generated token - if cur_pos + 1 + i >= prompt_len and print_log: - decoded_text = self.tokenizer.decode([new_token], skip_special_tokens=True) - print(decoded_text, end="", flush=True) - - # Check for EOS - if new_token == self.eos_id: - finished = True - break - - cur_pos += num_accepted - - if print_log: - print("\n") - total_tokens = sum(decode_accepted_counts) - logger.info(f"--Number of forward calls (decode): {len(decode_accepted_counts)}") - logger.info(f"--Total tokens generated: {total_tokens}") - if len(decode_accepted_counts) > 0: - avg_accepted = sum(decode_accepted_counts) / len(decode_accepted_counts) - min_accepted = min(decode_accepted_counts) - max_accepted = max(decode_accepted_counts) - logger.info( - f"--Accepted tokens per call: mean={avg_accepted:.2f}, " - f"min={min_accepted}, max={max_accepted}" - ) - - # Calculate correct TPS accounting for MTP's multiple tokens per call - if len(decode_time_list) > 5: - total_decode_time = sum(decode_time_list[5:]) # skip warmup - tokens_after_warmup = ( - sum(decode_accepted_counts[5:]) - if len(decode_accepted_counts) > 5 - else total_tokens - ) - effective_tps = ( - tokens_after_warmup / total_decode_time if total_decode_time > 0 else 0 - ) - avg_time_ms = total_decode_time / len(decode_time_list[5:]) * 1000 - logger.info(f"--Avg forward time: {avg_time_ms:.2f}ms") - logger.info(f"--Effective TPS (with MTP): {effective_tps:.2f} tokens/s") - - print("\n") - - # Reset sequence after generation - self.decode_layer.reset_sequence() - - # Extract completion tokens - completion_tokens = [] - for _, toks in enumerate(tokens.tolist()): - toks = toks[prompt_len : prompt_len + self.max_new_tokens] - # Remove -1 padding and tokens after EOS - toks = [t for t in toks if t != -1] - if self.eos_id in toks: - toks = toks[: toks.index(self.eos_id)] - completion_tokens.append(toks) - - decoded_tokens = self.tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) - - return ( - f"{decoded_tokens[0]}\n" if decoded_tokens else "", - decode_time_list, - decode_accepted_counts, - ) diff --git a/python/models/deepseek_v3_2/generator.py b/python/models/deepseek_v3_2/generator.py new file mode 100644 index 0000000..17252c9 --- /dev/null +++ b/python/models/deepseek_v3_2/generator.py @@ -0,0 +1,542 @@ +"""DSA show hands for deepseek v3.2.""" + +import math +import time + +import torch +from transformers import AutoTokenizer + +from tilert import logger +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.deepseek_v3_2.modules.end2end import ShowHandsDSALayer +from tilert.models.deepseek_v3_2.temp_var_indices import Idx +from tilert.tilert_init import tilert_init + +__all__ = [ + "DSAv32Generator", + "stats_time", +] + + +def stats_time(time_list: list[float], title: str) -> None: + if len(time_list) > 0: + avg_time = sum(time_list) / len(time_list) + std_dev = math.sqrt(sum((x - avg_time) ** 2 for x in time_list) / len(time_list)) + logger.info(title) + logger.info(f"--Average time taken to generate token: {avg_time * 1000:.4f} ms") + logger.info(f"--Standard deviation of time: {std_dev * 1000:.4f} ms") + logger.info(f"--Effective tokens per second: {1 / avg_time:.4f}") + + +class DSAv32Generator: + def __init__( + self, + model_args: ModelArgs, + max_new_tokens: int = 100, + temperature: float = 1.0, + model_weights_dir: str = "", + with_mtp: bool = False, + use_topp: bool = False, + top_p: float = 0.9, + top_k: int = 256, + sampling_seed: int = 42, + ): + """Initialize the DSAv32Generator. + + Args: + max_new_tokens: Maximum number of new tokens to generate. Defaults to 100. + temperature: Temperature for sampling. Defaults to 1.0. + model_weights_dir: Path of the model weights directory. + with_mtp: Whether to use MTP (Multi-Token Prediction) for speculative decoding. + use_topp: Whether to use top-p (nucleus) sampling instead of top-1 (argmax). + top_p: Top-p threshold for nucleus sampling. Defaults to 0.9. + top_k: Number of top-k candidates for top-p sampling. Defaults to 256. + sampling_seed: Sampling seed for top-p (fixed per request). Defaults to 42. + """ + torch.set_num_threads(64) + self.model_weights_dir = model_weights_dir + + self.max_new_tokens = max_new_tokens + self.temperature = temperature + self.with_mtp = with_mtp + self.use_topp = use_topp + self.top_p = top_p + self.top_k = top_k + self.sampling_seed = sampling_seed + + self.config = model_args + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_weights_dir, trust_remote_code=True + ) # nosec B615 + self.eos_id = self.tokenizer.eos_token_id + self.batch_size = 1 # fixed batch size to 1 for now + + self.default_device = torch.device("cuda:0") + + self.decode_layer = ShowHandsDSALayer( + model_args=self.config, + model_path=self.model_weights_dir, + with_mtp=with_mtp, + use_topp=use_topp, + top_p=top_p, + top_k=top_k, + ) + + self.mtp_seq_len = 4 if with_mtp else 1 + + def init(self) -> None: + """Initialize the ShowHandsGenerator.""" + tilert_init() + + def cleanup(self) -> None: + """Cleanup the ShowHandsGenerator.""" + self.decode_layer.cleanup() + + def init_random_weights(self) -> None: + """Random initialize the weights.""" + self.decode_layer.init_random_weights() + + def from_pretrained(self) -> None: + """Load the model weights from the given path.""" + self.decode_layer.from_pretrained(self.model_weights_dir) + + def update_sampling_params( + self, + temperature: float = 1.0, + top_p: float = 0.95, + top_k: int = 256, + use_topp: bool = True, + ) -> None: + """Update sampling parameters for the next generation.""" + self.temperature = temperature + self.use_topp = use_topp + self.top_p = top_p + self.top_k = top_k + self.decode_layer.update_sampling_config( + temperature=temperature, top_p=top_p, top_k=top_k, use_topp=use_topp + ) + + @torch.inference_mode() + def generate( + self, + prompt: str, + print_log: bool = True, + with_mtp: bool | None = None, + prompt_tokens: list[int] | None = None, + ) -> tuple[str, list[float], list[int]]: + """Main function to load the model and perform single sequence generation. + + Args: + prompt: The input prompt string. + print_log: Whether to print generation logs. + with_mtp: Override MTP mode for this call. None uses self.with_mtp. + Requires MTP weights to have been loaded (self.with_mtp=True). + prompt_tokens: Pre-tokenized prompt tokens. If provided, skip tokenization + and use these tokens directly (useful for exact-length benchmarking). + + Returns: + Tuple of (result_text, time_list, accepted_counts). + accepted_counts is empty for non-MTP mode. + """ + active_mtp = with_mtp if with_mtp is not None else self.with_mtp + if active_mtp and not self.with_mtp: + raise ValueError("Cannot use MTP mode: MTP weights were not loaded") + self.decode_layer.set_sampling_seed(self.sampling_seed, with_mtp=active_mtp) + if active_mtp: + return self._generate_with_mtp(prompt, print_log, prompt_tokens=prompt_tokens) + result, time_list = self._generate_without_mtp( + prompt, print_log, with_mtp=active_mtp, prompt_tokens=prompt_tokens + ) + return result, time_list, [] # Empty accepted_counts for non-MTP + + def _generate_without_mtp( + self, + prompt: str, + print_log: bool = True, + with_mtp: bool = False, + prompt_tokens: list[int] | None = None, + ) -> tuple[str, list[float]]: + """Standard generation without MTP.""" + if prompt_tokens is None: + prompt_tokens = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], add_generation_prompt=True + ) + + max_seq_len = self.config.max_seq_len + prompt_len = len(prompt_tokens) + total_len = min(max_seq_len, self.max_new_tokens + prompt_len) + + tokens = torch.full( + (self.batch_size, total_len), -1, dtype=torch.long, device=self.default_device + ) + tokens[0, :prompt_len] = torch.tensor( + prompt_tokens, dtype=torch.long, device=self.default_device + ) + prompt_mask = tokens != -1 + + prev_pos = 0 + finished = torch.tensor( + [False] * self.batch_size, dtype=torch.bool, device=self.default_device + ) + + time_list = [] + for cur_pos_val in range(1, total_len): + start_time = time.time() + multi_devices_results = self.decode_layer.forward( + tokens[0, prev_pos], with_mtp=with_mtp + ) + end_time = time.time() + time_list.append(end_time - start_time) + + intermediates, *_ = multi_devices_results[0] + next_token = intermediates[Idx.TOKEN_OUT][0][0] # only the first token + + # replace the next token with the prompt token if the prompt mask is True + next_token = torch.where( + prompt_mask[0, cur_pos_val], tokens[0, cur_pos_val], next_token + ) + tokens[0, cur_pos_val] = next_token + finished |= torch.logical_and(~prompt_mask[0, cur_pos_val], next_token == self.eos_id) + prev_pos = cur_pos_val + if cur_pos_val >= prompt_len: + decoded_tokens = self.tokenizer.decode( + [next_token.item()], skip_special_tokens=True + ) + if print_log: + print(decoded_tokens, end="", flush=True) + + if finished.all(): + break + + if print_log: + print("\n") + logger.info(f"--Number of tokens generated: {len(time_list)}") + + stats_time(time_list, "==== Performance ====") + print("\n") + + # Reset sequence after generation, i.e. reset the cur_pos to 0 internally + self.decode_layer.reset_sequence() + + completion_tokens = [] + for _, toks in enumerate(tokens.tolist()): + toks = toks[prompt_len : prompt_len + self.max_new_tokens] + if self.eos_id in toks: + toks = toks[: toks.index(self.eos_id)] + completion_tokens.append(toks) + + decoded_tokens = self.tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) + + return f"{decoded_tokens[0]}\n" if decoded_tokens else "", time_list + + def _generate_with_mtp( + self, + prompt: str, + print_log: bool = True, + prompt_tokens: list[int] | None = None, + ) -> tuple[str, list[float], list[int]]: + """Generation with MTP (Multi-Token Prediction) speculative decoding.""" + if prompt_tokens is None: + prompt_tokens = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], add_generation_prompt=True + ) + + max_seq_len = self.config.max_seq_len + prompt_len = len(prompt_tokens) + total_len = min(max_seq_len, self.max_new_tokens + prompt_len) + + # Output tokens buffer + tokens = torch.full( + (self.batch_size, total_len), -1, dtype=torch.long, device=self.default_device + ) + tokens[0, :prompt_len] = torch.tensor( + prompt_tokens, dtype=torch.long, device=self.default_device + ) + + prefill_time_list = [] + decode_time_list = [] + decode_accepted_counts = [] # Only track decode phase for statistics + cur_pos = 0 # Current position in the output sequence + + # Prefill phase: process prompt tokens in non-overlapping chunks. + # Each chunk fills unique KV cache positions for both main model and MTP[0]. + while cur_pos < prompt_len - 1: + draft_end = min(cur_pos + self.mtp_seq_len, prompt_len) + draft_tokens = tokens[0, cur_pos:draft_end].clone() + actual_token_count = draft_tokens.shape[0] + + # Pad if needed (use last token for padding) + if actual_token_count < self.mtp_seq_len: + pad_token = draft_tokens[-1].item() + padding = torch.full( + (self.mtp_seq_len - actual_token_count,), + pad_token, + dtype=torch.long, + device=self.default_device, + ) + draft_tokens = torch.cat([draft_tokens, padding]) + + draft_tokens = draft_tokens.reshape(1, self.mtp_seq_len).to(torch.int32) + + # Provide the extra token for MTP[0]'s shifted input last position. + # MTP[0] needs tokens[cur_pos+1 : cur_pos+mtp_seq_len+1], so the + # extra token is at cur_pos + mtp_seq_len. + mtp_extra_pos = cur_pos + self.mtp_seq_len + if mtp_extra_pos < prompt_len: + mtp_extra_token = int(tokens[0, mtp_extra_pos].item()) + else: + # Beyond prompt — use last valid draft token as padding + mtp_extra_token = int(tokens[0, draft_end - 1].item()) + self.decode_layer.set_prefill_mtp_extra_token(mtp_extra_token) + + # Tell GPU how many tokens are valid (for cur_pos advancement) + self.decode_layer.set_prefill_valid_tokens(actual_token_count) + + start_time = time.time() + self.decode_layer.forward(draft_tokens, with_mtp=True) + end_time = time.time() + prefill_time_list.append(end_time - start_time) + + # No overlap: advance by the full actual_token_count + cur_pos += actual_token_count + + # After no-overlap prefill, cur_pos may have overshot to prompt_len. + # Reset to prompt_len - 1 for correct decode start (first decode + # reprocesses the last prompt token position). + cur_pos = prompt_len - 1 + self.set_cur_pos(prompt_len - 1) + + # Decode phase: speculative decoding + # Set prefill_valid_tokens to 0 to switch to decode mode + self.decode_layer.set_prefill_valid_tokens(0) + + finished = False + while cur_pos < total_len - 1 and not finished: + # Get next_draft_tokens from previous iteration + # (or use last prompt tokens for first decode) + if cur_pos == prompt_len - 1: + # First decode iteration: use last prompt token repeated as placeholder drafts + # We can't use [t6, t7, t8, t9] because that would apply wrong RoPE positions + # (cur_pos=9 means positions 9,10,11,12, but t6 should be at position 6) + last_token = tokens[0, prompt_len - 1].item() + draft_tokens = torch.full( + (self.mtp_seq_len,), + last_token, + dtype=torch.long, + device=self.default_device, + ) + draft_tokens = draft_tokens.reshape(1, self.mtp_seq_len).to(torch.int32) + else: + # Use next_draft_tokens from previous iteration + draft_tokens = self.decode_layer.get_next_draft_tokens(0).reshape( + 1, self.mtp_seq_len + ) + + start_time = time.time() + self.decode_layer.forward(draft_tokens, with_mtp=True) + end_time = time.time() + decode_time_list.append(end_time - start_time) + + num_accepted = self.decode_layer.get_num_accepted(0) + # Use predicted_tokens for output (not next_draft_tokens which is for next iteration) + predicted_tokens = self.decode_layer.get_predicted_tokens(0).flatten() + decode_accepted_counts.append(num_accepted) + + # Add accepted tokens to output + num_output_tokens = num_accepted + for i in range(num_output_tokens): + if cur_pos + 1 + i >= total_len: + break + new_token = int(predicted_tokens[i].item()) + tokens[0, cur_pos + 1 + i] = new_token + + # Print generated token + if cur_pos + 1 + i >= prompt_len and print_log: + decoded_text = self.tokenizer.decode([new_token], skip_special_tokens=True) + print(decoded_text, end="", flush=True) + + # Check for EOS + if new_token == self.eos_id: + finished = True + break + + cur_pos += num_accepted + + if print_log: + print("\n") + total_tokens = sum(decode_accepted_counts) + logger.info(f"--Number of forward calls (decode): {len(decode_accepted_counts)}") + logger.info(f"--Total tokens generated: {total_tokens}") + if len(decode_accepted_counts) > 0: + avg_accepted = sum(decode_accepted_counts) / len(decode_accepted_counts) + min_accepted = min(decode_accepted_counts) + max_accepted = max(decode_accepted_counts) + logger.info( + f"--Accepted tokens per call: mean={avg_accepted:.2f}, " + f"min={min_accepted}, max={max_accepted}" + ) + + # Calculate correct TPS accounting for MTP's multiple tokens per call + if decode_time_list: + total_decode_time = sum(decode_time_list) + effective_tps = total_tokens / total_decode_time if total_decode_time > 0 else 0 + avg_time_ms = total_decode_time / len(decode_time_list) * 1000 + logger.info(f"--Avg forward time: {avg_time_ms:.2f}ms") + logger.info(f"--Effective TPS (with MTP): {effective_tps:.2f} tokens/s") + + print("\n") + + # Reset sequence after generation + self.decode_layer.reset_sequence() + + # Extract completion tokens + completion_tokens = [] + for _, toks in enumerate(tokens.tolist()): + toks = toks[prompt_len : prompt_len + self.max_new_tokens] + # Remove -1 padding and tokens after EOS + toks = [t for t in toks if t != -1] + if self.eos_id in toks: + toks = toks[: toks.index(self.eos_id)] + completion_tokens.append(toks) + + decoded_tokens = self.tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) + + return ( + f"{decoded_tokens[0]}\n" if decoded_tokens else "", + decode_time_list, + decode_accepted_counts, + ) + + def inject_cache( + self, + layer_caches: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], + start_pos: int = 0, + end_pos: int | None = None, + ) -> None: + """Inject external cache data into TileRT. + + This API allows injecting pre-computed KI/KV/PE cache data from an external + prefill system, enabling prefill-decode disaggregation. + + Args: + layer_caches: List of (ki, kv, pe) tuples for each layer (0 to NUM_LAYERS-1). + Each tensor should be BF16 with shape [seqlen, dim] where: + - ki: [seqlen, 128] - compressed key + - kv: [seqlen, 512] - compressed key-value + - pe: [seqlen, 64] - position encoding cache + start_pos: Start position in cache to write (0-indexed). Defaults to 0. + end_pos: End position in cache (exclusive). If None, uses seqlen from tensors. + + Example: + >>> # Load cache from external prefill system + >>> layer_caches = [] # List of 61 (ki, kv, pe) tuples + >>> for layer_id in range(61): + ... ki = load_ki_for_layer(layer_id) # [seqlen, 128] bf16 + ... kv = load_kv_for_layer(layer_id) # [seqlen, 512] bf16 + ... pe = load_pe_for_layer(layer_id) # [seqlen, 64] bf16 + ... layer_caches.append((ki, kv, pe)) + >>> generator.inject_cache(layer_caches, start_pos=0) + >>> generator.set_cur_pos(seqlen) # Set RoPE position + >>> # Continue generation from cache + """ + num_layers = len(layer_caches) + if num_layers == 0: + logger.warning("inject_cache called with empty layer_caches") + return + + # Infer seqlen from first tensor if end_pos not specified + first_ki, _, _ = layer_caches[0] + seqlen = first_ki.size(0) + if end_pos is None: + end_pos = start_pos + seqlen + + cache_len = end_pos - start_pos + logger.info(f"Injecting cache: {num_layers} layers, positions [{start_pos}, {end_pos})") + + num_devices = self.decode_layer.num_devices + + for device_id in range(num_devices): + _, caches, _, _ = self.decode_layer._get_device_result(device_id) + + for layer_id, (ki, kv, pe) in enumerate(layer_caches): + if layer_id >= num_layers: + logger.warning(f"Layer index {layer_id} is out of bounds, skipping.") + break + + base_idx = layer_id * 3 + + # Copy to device and inject into cache + # Cache layout: [batch=1, max_seq_len, dim] + # External data: [seqlen, dim] + ki_src = ki[:cache_len].to(f"cuda:{device_id}") + kv_src = kv[:cache_len].to(f"cuda:{device_id}") + pe_src = pe[:cache_len].to(f"cuda:{device_id}") + + caches[base_idx + 0][0, start_pos:end_pos, :].copy_(ki_src) + caches[base_idx + 1][0, start_pos:end_pos, :].copy_(kv_src) + caches[base_idx + 2][0, start_pos:end_pos, :].copy_(pe_src) + + logger.info(f"Cache injection completed for {num_devices} devices") + + def set_cur_pos(self, cur_pos: int) -> None: + """Set the current position for RoPE in C++ backend. + + This should be called after inject_cache() to ensure the C++ global + g_cur_pos matches the injected cache length. This is critical for + correct RoPE position encoding during continued generation. + + For MTP mode, sets the GPU tensor at intermediates[31] directly. + For non-MTP mode, calls the C++ dsa_show_hands_set_cur_pos API. + + Args: + cur_pos: The current sequence position (typically the length of prefilled tokens). + + Example: + >>> generator.inject_cache(layer_caches, start_pos=0) + >>> generator.set_cur_pos(prefill_len) # Set position to prefill length + >>> # Now generate continues from the correct position + """ + if self.with_mtp: + # MTP E2E uses g_cur_pos_tensors which is the GPU tensor + num_devices = self.decode_layer.num_devices + for device_id in range(num_devices): + intermediates, _, _, _ = self.decode_layer._get_device_result(device_id) + cur_pos_tensor = intermediates[Idx.CUR_POS] + cur_pos_tensor.fill_(cur_pos) + else: + # Non-MTP uses the C++ global g_cur_pos + torch.ops.tilert.dsa_show_hands_set_cur_pos(cur_pos) + + def inject_last_hidden_state(self, last_hidden_state: torch.Tensor) -> None: + """Inject the last hidden state for MTP mode. + + For MTP (Multi-Token Prediction), the MTP preprocess layer needs the + last hidden state from the main model's last token. This method injects + the hidden state into intermediates[33] (last_hidden_states slot). + + Args: + last_hidden_state: [hidden_size] or [1, hidden_size] BF16 tensor. + The hidden state of the last token from prefill. + + Example: + >>> # After inject_cache, inject the last hidden state for MTP + >>> generator.inject_last_hidden_state(last_hidden_state) + >>> # Then set cur_pos and start generation + """ + if not self.with_mtp: + logger.warning("inject_last_hidden_state called but with_mtp is False, skipping") + return + + # Normalize shape to [1, hidden_size] + if last_hidden_state.dim() == 1: + last_hidden_state = last_hidden_state.unsqueeze(0) + + num_devices = self.decode_layer.num_devices + for device_id in range(num_devices): + intermediates, _, _, _ = self.decode_layer._get_device_result(device_id) + # Shape: [batch=1, seq=4, hidden_size], we set seq[0] since it's the last token + lhs_tensor = intermediates[Idx.LAST_HIDDEN_STATES] + lhs_src = last_hidden_state.to(f"cuda:{device_id}") + lhs_tensor[0, 0, :].copy_(lhs_src.squeeze(0)) + + logger.info(f"Injected last_hidden_state to {num_devices} devices") diff --git a/python/models/deepseek_v3_2/model_args.py b/python/models/deepseek_v3_2/model_args.py index c38cd3b..b149edf 100644 --- a/python/models/deepseek_v3_2/model_args.py +++ b/python/models/deepseek_v3_2/model_args.py @@ -14,6 +14,7 @@ class ModelArgs: Data class for defining model arguments and hyperparameters. Attributes: + arch_name (str): Architecture name. max_batch_size (int): Maximum batch size. max_seq_len (int): Maximum sequence length. dtype (Literal["bf16", "fp8"]): Data type for computations. @@ -37,25 +38,27 @@ class ModelArgs: qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. v_head_dim (int): Dimension for value projections. - original_seq_len (int): Original sequence length. + original_seq_len (Optional[int]): Original sequence length. rope_theta (float): Base for rotary positional encoding. - rope_factor (float): Scaling factor for extended sequence lengths. - beta_fast (int): Fast beta correction factor. - beta_slow (int): Slow beta correction factor. + rope_factor (Optional[float]): Scaling factor for extended sequence lengths. + beta_fast (Optional[int]): Fast beta correction factor. + beta_slow (Optional[int]): Slow beta correction factor. mscale (float): Scaling factor for extended attention. index_head_dim (int): Dimension for index head. index_topk (int): Top-k for index head. """ + arch_name = "deepseek_v3_2" + max_batch_size: int = 1 # NOTE: the current implementation only supports a batch size being 1 - max_seq_len: int = 4096 * 4 - dtype: Literal["bf16", "fp8"] = "bf16" + max_seq_len: int = 160 * 1024 # 160K + dtype: Literal["bf16", "fp8"] = "fp8" scale_fmt: str | None = None vocab_size: int = 129280 dim: int = 7168 inter_dim: int = 18432 - moe_inter_dim: int = 2048 // 8 + moe_inter_dim: int = 2048 n_layers: int = 61 n_dense_layers: int = 3 n_heads: int = 128 @@ -67,7 +70,7 @@ class ModelArgs: n_expert_groups: int = 8 n_limited_groups: int = 4 score_func: Literal["softmax", "sigmoid"] = "softmax" - route_scale: float = 1.0 + route_scale: float = 2.5 # mla q_lora_rank: int = 1536 @@ -77,14 +80,21 @@ class ModelArgs: v_head_dim: int = 128 # yarn - original_seq_len: int = 4096 + original_seq_len: int | None = 4096 rope_theta: float = 10000.0 - rope_factor: float = 40 - beta_fast: int = 32 - beta_slow: int = 1 + rope_factor: float | None = 40 + beta_fast: int | None = 32 + beta_slow: int | None = 1 mscale: float = 1.0 # index index_n_heads: int = 64 index_head_dim: int = 128 index_topk: int = 2048 + + kv_cache_pad: int = 8 + + # quant + block_size: int = 128 + + eps: float = 1e-6 diff --git a/python/models/deepseek_v3_2/modules/__init__.py b/python/models/deepseek_v3_2/modules/__init__.py new file mode 100644 index 0000000..937085b --- /dev/null +++ b/python/models/deepseek_v3_2/modules/__init__.py @@ -0,0 +1,11 @@ +"""DeepSeek v3.2 high-level Python modules (MLA, MLP, MTP, etc.).""" + +__all__ = [ + "dsa", + "end2end", + "mla", + "mlp", + "moe", + "mtp", + "mtp_preprocess", +] diff --git a/python/models/deepseek_v3_2/modules/dsa.py b/python/models/deepseek_v3_2/modules/dsa.py new file mode 100644 index 0000000..64116f9 --- /dev/null +++ b/python/models/deepseek_v3_2/modules/dsa.py @@ -0,0 +1,154 @@ +from typing import Any + +import torch + +from tilert.models.base import SerializableTileRTModule +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.deepseek_v3_2.modules.mlp import MlpBlock +from tilert.models.deepseek_v3_2.modules.moe import MoeBlock +from tilert.models.deepseek_v3_2.ops import RMSNormHeadProj +from tilert.models.deepseek_v3_2.temp_var_indices import TEMP_VARS_SIZE, Idx + + +class Dsa(SerializableTileRTModule): + """DSA module.""" + + def __init__(self, model_args: ModelArgs, device_id: int, num_devices: int): + super().__init__( + model_args=model_args, + device_id=device_id, + num_devices=num_devices, + remove_selected=True, + ) + + for layer_idx in range(model_args.n_layers): + block_type = MlpBlock if layer_idx < model_args.n_dense_layers else MoeBlock + block = block_type(model_args=model_args, device_id=device_id, num_devices=num_devices) + self.register_op(block, prefix=f"layer_{layer_idx}_", suffix=f"_dev_{device_id}") + + self.register_op( + RMSNormHeadProj(model_args=model_args, device_id=device_id, num_devices=num_devices), + prefix=f"layer_{model_args.n_layers}_", + suffix=f"_dev_{device_id}", + retain_weights=True, + ) + + self.embed_tokens_weight = None + self.freqs_cis = None + + def init_tilert_weights(self, state_dicts: dict[str, torch.Tensor]) -> None: + super().init_tilert_weights(state_dicts) + self.embed_tokens_weight = state_dicts["model.embed_tokens.weight"] + self.freqs_cis = state_dicts["freqs_cis"] + + def get_weights_list(self) -> list[torch.Tensor]: + return [*super().get_weights_list(), self.embed_tokens_weight, self.freqs_cis] + + def get_temp_vars( + self, batch_size: int, seq_len: int, extra_args: dict[str, Any] | None = None + ) -> list[torch.Tensor]: + bf16_desc = {"dtype": torch.bfloat16, "device": f"cuda:{self.device_id}"} + fp32_desc = {"dtype": torch.float32, "device": f"cuda:{self.device_id}"} + int32_desc = {"dtype": torch.int32, "device": f"cuda:{self.device_id}"} + int64_desc = {"dtype": torch.int64, "device": f"cuda:{self.device_id}"} + fp8_desc = {"dtype": torch.float8_e4m3fn, "device": f"cuda:{self.device_id}"} + + assert extra_args is not None + temperature = extra_args["temperature"] + top_p = extra_args["top_p"] + top_k = extra_args["top_k"] + use_topp = extra_args["use_topp"] + + dim = self.model_args.dim + batch_seq = (batch_size, seq_len) + q_lora_rank = self.model_args.q_lora_rank + kv_lora_rank = self.model_args.kv_lora_rank + qk_nope_head_dim = self.model_args.qk_nope_head_dim + n_local_heads = self.model_args.n_heads // self.num_devices + qk_rope_head_dim = self.model_args.qk_rope_head_dim + index_head_dim = self.model_args.index_head_dim + v_head_dim = self.model_args.v_head_dim + n_index_heads = self.model_args.index_n_heads + max_seq_len = self.model_args.max_seq_len + index_topk = self.model_args.index_topk + n_routed_experts = self.model_args.n_routed_experts + n_activated_experts = self.model_args.n_activated_experts + n_total_experts = self.model_args.n_activated_experts + self.model_args.n_shared_experts + moe_inter_dim = self.model_args.moe_inter_dim // self.num_devices + vocab_size = self.model_args.vocab_size // self.num_devices + + temp_vars: list[torch.Tensor | None] = [None] * TEMP_VARS_SIZE + + temp_vars[Idx.Q] = torch.zeros(*batch_seq, q_lora_rank, **bf16_desc) + temp_vars[Idx.KV] = torch.zeros(*batch_seq, kv_lora_rank, **bf16_desc) + temp_vars[Idx.KI] = torch.zeros(*batch_seq, index_head_dim, **bf16_desc) + temp_vars[Idx.Q_NOPE_DOWN] = torch.zeros( + *batch_seq, n_local_heads, qk_nope_head_dim, **bf16_desc + ) + temp_vars[Idx.Q_PE] = torch.zeros(*batch_seq, n_local_heads, qk_rope_head_dim, **bf16_desc) + temp_vars[Idx.IQ] = torch.zeros(*batch_seq, n_index_heads, index_head_dim, **bf16_desc) + temp_vars[Idx.IQ_RT] = torch.zeros(*batch_seq, n_index_heads, index_head_dim, **bf16_desc) + temp_vars[Idx.IDX_SCORES] = torch.zeros(*batch_seq, n_index_heads, **bf16_desc) + temp_vars[Idx.IDX_LOGITS] = torch.zeros( + *batch_seq, max_seq_len + self.model_args.kv_cache_pad, **fp32_desc + ) + temp_vars[Idx.IDX_SELECTS] = torch.zeros(*batch_seq, index_topk, **int32_desc) + temp_vars[Idx.Q_NOPE] = torch.zeros(*batch_seq, n_local_heads, kv_lora_rank, **bf16_desc) + temp_vars[Idx.O] = torch.zeros(*batch_seq, n_local_heads, kv_lora_rank, **bf16_desc) + temp_vars[Idx.O_ACC] = torch.zeros(*batch_seq, n_local_heads, 32, kv_lora_rank, **fp32_desc) + temp_vars[Idx.O_LSE] = torch.empty(*batch_seq, n_local_heads, **fp32_desc) + temp_vars[Idx.O_LSE_ACC] = torch.empty(*batch_seq, n_local_heads, 32, **fp32_desc) + temp_vars[Idx.PROJ_O] = torch.zeros(*batch_seq, n_local_heads, v_head_dim, **bf16_desc) + temp_vars[Idx.UNPROJ_O] = torch.zeros(*batch_seq, dim, **bf16_desc) + temp_vars[Idx.SCORES] = torch.zeros(*batch_seq, n_routed_experts, **fp32_desc) + temp_vars[Idx.X_MLP_IN] = torch.zeros(*batch_seq, dim, **bf16_desc) + exp_up_gate = torch.zeros(*batch_seq, n_total_experts, moe_inter_dim, **bf16_desc) + temp_vars[Idx.UP_GATE] = exp_up_gate + temp_vars[Idx.SEL_PROBS] = torch.zeros(*batch_seq, n_activated_experts, **fp32_desc) + temp_vars[Idx.SEL_INDICES] = torch.zeros(*batch_seq, n_activated_experts, **int32_desc) + temp_vars[Idx.EXP_OUT] = torch.zeros(*batch_seq, dim, **bf16_desc) + temp_vars[Idx.X_RMSNORM] = torch.zeros(*batch_seq, dim, **bf16_desc) + temp_vars[Idx.LOGITS_OUT] = torch.zeros(*batch_seq, vocab_size, **fp32_desc) + temp_vars[Idx.TOKEN_OUT] = torch.zeros(*batch_seq, 1, **int32_desc) + + temp_vars[Idx.EMBEDDING_RMSNORM] = torch.zeros(*batch_seq, dim, **bf16_desc) + temp_vars[Idx.HIDDEN_RMSNORM] = torch.zeros(*batch_seq, dim, **bf16_desc) + temp_vars[Idx.EH_PROJ] = torch.zeros(*batch_seq, dim, **bf16_desc) + temp_vars[Idx.X_TENSOR] = torch.zeros(*batch_seq, dim, **bf16_desc) + temp_vars[Idx.ROPE_FREQS] = torch.zeros(*batch_seq, qk_rope_head_dim, **fp32_desc) + temp_vars[Idx.CUR_POS] = torch.zeros(batch_size, **int32_desc) + temp_vars[Idx.TOKEN_ID] = torch.zeros(*batch_seq, 1, **int32_desc) + temp_vars[Idx.LAST_HIDDEN_STATES] = torch.zeros(*batch_seq, dim, **bf16_desc) + + temp_vars[Idx.DRAFT_TOKENS] = torch.zeros(*batch_seq, **int32_desc) + temp_vars[Idx.PREDICTED_TOKENS] = torch.zeros(*batch_seq, 1, **int32_desc) + temp_vars[Idx.PREDICTED_HIDDEN] = torch.zeros(*batch_seq, dim, **bf16_desc) + temp_vars[Idx.ACCEPTED_TOKENS] = torch.zeros(batch_size, **int32_desc) + temp_vars[Idx.NEXT_DRAFT_TOKENS] = torch.zeros(*batch_seq, **int32_desc) + + temp_vars[Idx.X_QUANT] = torch.zeros(*batch_seq, dim, **fp8_desc) + temp_vars[Idx.X_SCALE] = torch.zeros( + *batch_seq, dim // self.model_args.block_size, **fp32_desc + ) + temp_vars[Idx.MOE_UP_GATE] = torch.zeros_like(exp_up_gate) + + # temp_vars[Idx.IDX_SEL_WS] = torch.zeros(*batch_seq, 4, index_topk * 2, **int32_desc) + temp_vars[Idx.IDX_SEL_WS] = torch.zeros(*batch_seq, (200 * 1024 + 258), **int32_desc) + + temp_vars[Idx.MTP0_TOKEN_OUT] = torch.zeros(*batch_seq, 1, **int32_desc) + temp_vars[Idx.MTP1_TOKEN_OUT] = torch.zeros(*batch_seq, 1, **int32_desc) + temp_vars[Idx.MTP0_EXP_OUT] = torch.zeros(*batch_seq, dim, **bf16_desc) + + temp_vars[Idx.SAMPLING_SEED] = torch.zeros(*batch_seq, **int64_desc) + temp_vars[Idx.SAMPLING_POSITIONS] = torch.zeros(*batch_seq, **int64_desc) + temp_vars[Idx.SAMPLING_CONFIG] = torch.tensor( + [temperature, top_p, top_k, use_topp], **fp32_desc + ) + temp_vars[Idx.TOP_P_SCORES] = torch.zeros(*batch_seq, **fp32_desc) + temp_vars[Idx.TOP_P_DEBUG] = torch.zeros(*batch_seq, vocab_size, **fp32_desc) + + for i, t in enumerate(temp_vars): + if t is None: + raise RuntimeError(f"temp_vars[{i}] ({Idx(i).name}) was not initialized") + + return temp_vars # type: ignore[return-value] diff --git a/python/models/deepseek_v3_2/modules/end2end.py b/python/models/deepseek_v3_2/modules/end2end.py new file mode 100644 index 0000000..47a5671 --- /dev/null +++ b/python/models/deepseek_v3_2/modules/end2end.py @@ -0,0 +1,520 @@ +"""DSA show hands for deepseek v3.2.""" + +import json +import os +import sys +import threading +import time +from typing import Any + +import torch +from safetensors.torch import load_file + +from tilert import logger +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.deepseek_v3_2.modules.dsa import Dsa +from tilert.models.deepseek_v3_2.modules.mtp import MTP +from tilert.models.deepseek_v3_2.temp_var_indices import Idx, validate_temp_vars_layout +from tilert.models.utils import precompute_freqs_cis +from tilert.utils import get_profile_log_tensor + +__all__ = ["ShowHandsDSALayer"] + + +DeviceResult = tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], torch.Tensor] + + +def dsa_show_hands_prepare_money( + params: list[torch.Tensor], + temp_vars: list[torch.Tensor], + cache_vars: list[torch.Tensor], + profile_logs: torch.Tensor, + forward_max_seq_len: int, + with_mtp: bool = False, + is_glm5: bool = False, +) -> Any: + """Prepare money for show hands""" + mtp_flag = "_mtp_e2e" if with_mtp else "" + glm5_flag = "_glm5" if is_glm5 else "" + func_name = f"dsa{mtp_flag}_show_hands_prepare_money{glm5_flag}" + if mtp_flag: + return getattr(torch.ops.tilert, func_name)(params, temp_vars, cache_vars, profile_logs) + return getattr(torch.ops.tilert, func_name)( + params, temp_vars, cache_vars, profile_logs, forward_max_seq_len + ) + + +def dsa_show_hands(token_id: torch.Tensor, with_mtp: bool = False, is_glm5: bool = False) -> Any: + """Show hands with native MT""" + mtp_flag = "_mtp_e2e" if with_mtp else "" + glm5_flag = "_glm5" if is_glm5 else "" + func_name = f"dsa{mtp_flag}_show_hands{glm5_flag}" + return getattr(torch.ops.tilert, func_name)(token_id) + + +def dsa_show_hands_reset(with_mtp: bool = False, is_glm5: bool = False) -> Any: + """Reset show one hand""" + mtp_flag = "_mtp_e2e" if with_mtp else "" + glm5_flag = "_glm5" if is_glm5 else "" + func_name = f"dsa{mtp_flag}_show_hands_reset{glm5_flag}" + return getattr(torch.ops.tilert, func_name)() + + +def dsa_show_hands_go_home(with_mtp: bool = False, is_glm5: bool = False) -> Any: + """Go home""" + mtp_flag = "_mtp_e2e" if with_mtp else "" + glm5_flag = "_glm5" if is_glm5 else "" + func_name = f"dsa{mtp_flag}_show_hands_go_home{glm5_flag}" + return getattr(torch.ops.tilert, func_name)() + + +def dsa_show_hands_set_sampling_seed( + seed: int, with_mtp: bool = False, is_glm5: bool = False +) -> Any: + """Set the sampling seed (request-level, fixed for the entire request). + + Args: + seed: The sampling seed value. + """ + mtp_flag = "_mtp_e2e" if with_mtp else "" + glm5_flag = "_glm5" if is_glm5 else "" + func_name = f"dsa{mtp_flag}_show_hands_set_sampling_seed{glm5_flag}" + return getattr(torch.ops.tilert, func_name)(seed) + + +def dsa_mtp_e2e_show_hands_set_prefill_valid_tokens( + num_valid_tokens: int, is_glm5: bool = False +) -> Any: + """Set the number of valid (non-padding) tokens for prefill mode. + + This controls how many tokens are copied from draft_tokens to predicted_tokens + during prefill. Should be called before forward() when the chunk has padding. + + Args: + num_valid_tokens: Number of valid tokens in the chunk (1-4). + """ + mtp_flag = "_mtp_e2e" + glm5_flag = "_glm5" if is_glm5 else "" + func_name = f"dsa{mtp_flag}_show_hands_set_prefill_valid_tokens{glm5_flag}" + return getattr(torch.ops.tilert, func_name)(num_valid_tokens) + + +def dsa_mtp_e2e_show_hands_set_prefill_mtp_extra_token(token: int, is_glm5: bool = False) -> Any: + """Set the extra token for MTP[0] shifted input during prefill. + + This is the prompt token at (cur_pos + mtp_seq_len), used as the last position + of MTP[0]'s shifted input to enable no-overlap prefill chunking. + + Args: + token: The extra prompt token id (int32). + """ + mtp_flag = "_mtp_e2e" + glm5_flag = "_glm5" if is_glm5 else "" + func_name = f"dsa{mtp_flag}_show_hands_set_prefill_mtp_extra_token{glm5_flag}" + return getattr(torch.ops.tilert, func_name)(token) + + +class ShowHandsDSALayer: + """Show hands DSA for deepseek v3.2.""" + + def __init__( + self, + model_args: ModelArgs, + model_path: str = "", + with_weight_conversion: bool = True, + with_mtp: bool = False, + temperature: float = 1.0, + top_p: float = 0.9, + top_k: int = 256, + use_topp: bool = False, + ) -> None: + validate_temp_vars_layout() + print(f"Model args: {model_args.arch_name}") + for k_arg, v_arg in model_args.__dict__.items(): + print(f" - {k_arg}: {v_arg}") + self.model_args = model_args + self.is_glm5 = self.model_args.arch_name == "glm_5" + assert self.model_args.arch_name in ["deepseek_v3_2", "glm_5"] + + self.num_devices = 8 + self.forward_max_seq_len = 4 + + self.model_path = model_path + self.with_weight_conversion = with_weight_conversion + self.with_mtp = with_mtp + + self.multi_devices_results: list[DeviceResult | None] = [None] * torch.cuda.device_count() + + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.use_topp = use_topp + + def _gen_freqs_cis(self) -> torch.Tensor: + freqs_cis = precompute_freqs_cis(self.model_args) + return torch.view_as_real(freqs_cis).reshape(freqs_cis.shape[0], -1) + + def load_device_weights( + self, model_path: str, device_id: int, extra_keys: list + ) -> dict[str, torch.Tensor]: + index_file = "model.safetensors.index.json" + with open(os.path.join(model_path, index_file), encoding="utf-8") as f: + weights_index = json.load(f) + weight_file_map = weights_index["weight_map"] + + weights_list = [_k for _k in weight_file_map.keys() if _k.endswith(f"dev_{device_id}")] + weights_list = [*weights_list, *extra_keys] + + target_files = set() + for weight_key in weights_list: + weight_file = weight_file_map[weight_key] + target_files.add(weight_file) + + state_dicts = {} + for weight_file in target_files: + logger.info(f"Loading weights from {weight_file} for device {device_id}") + state_dict = load_file( + os.path.join(model_path, weight_file), device=f"cuda:{device_id}" + ) + state_dicts.update(state_dict) + del state_dict + torch.cuda.empty_cache() + + state_dicts["freqs_cis"] = self._gen_freqs_cis().to(device_id) + return state_dicts + + def update_sampling_config( + self, temperature: float, top_p: float, top_k: int, use_topp: bool = True + ) -> None: + """Update sampling config, re-capturing CUDA graphs if parameters changed. + + Sampling parameters are baked into CUDA graph instructions at prepare_money + time, so any change requires a full teardown + re-capture cycle. + """ + new_config = (temperature, top_p, top_k, use_topp) + current_config = (self.temperature, self.top_p, self.top_k, self.use_topp) + if new_config == current_config: + return + + print( + f"Recapturing CUDA graphs: " + f"temperature={temperature}, top_p={top_p}, top_k={top_k}, use_topp={use_topp}" + ) + + # Teardown: stop all threads and unregister all modules + if self.with_mtp: + dsa_show_hands_go_home(True, self.is_glm5) + dsa_show_hands_go_home(False, self.is_glm5) + else: + dsa_show_hands_go_home(False, self.is_glm5) + + # Store new config + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.use_topp = use_topp + + # Update sampling_config tensor on all devices + for device_id in range(self.num_devices): + result = self.multi_devices_results[device_id] + if result is not None: + intermediates = result[0] + intermediates[Idx.SAMPLING_CONFIG].copy_( + torch.tensor( + [temperature, top_p, float(top_k), 1.0 if use_topp else 0.0], + dtype=torch.float32, + device=f"cuda:{device_id}", + ) + ) + + # Re-prepare all modules (re-captures CUDA graphs with new config) + for device_id in range(self.num_devices): + with torch.cuda.device(device_id): + intermediates, caches, params, profile_logs = self._get_device_result(device_id) + dsa_show_hands_prepare_money( + params, + intermediates, + caches, + profile_logs, + self.forward_max_seq_len, + self.with_mtp, + self.is_glm5, + ) + if self.with_mtp: + dsa_show_hands_prepare_money( + params[: self._base_params_count], + intermediates, + caches[: self._base_caches_count], + profile_logs, + self.forward_max_seq_len, + False, + self.is_glm5, + ) + + @staticmethod + def tot_size_in_bytes_aligned(temp_vars: list[torch.Tensor], aligned_size: int) -> int: + tot_size: int = 0 + for param in temp_vars: + aligned_param_size = (param.nbytes + aligned_size - 1) // aligned_size * aligned_size + tot_size += aligned_param_size + return tot_size + + def generate_params_with_continuous_storage( + self, temp_vars: list[torch.Tensor], device: torch.device, aligned_size: int = 1024 + ) -> list[torch.Tensor]: + tot_size = self.tot_size_in_bytes_aligned(temp_vars, aligned_size) + cloned_params = [] + large_tensor = torch.zeros(tot_size, device=device, dtype=torch.uint8) + offset = 0 + for param in temp_vars: + aligned_param_size = (param.nbytes + aligned_size - 1) // aligned_size * aligned_size + cloned_params.append( + large_tensor[offset : offset + param.nbytes].view(param.dtype).view(param.shape) + ) + offset += aligned_param_size + return cloned_params + + def _init_weights(self, model_path: str | None) -> None: + """Load the model weights from the given path or generate random weights.""" + + def __load_weights(device_id: int, model_path: str | None) -> None: + intermediates: list[torch.Tensor] = [] + caches: list[torch.Tensor] = [] + params: list[torch.Tensor] = [] + state_dicts = {} + start_time = time.time() + with torch.cuda.device(device_id): + assert model_path is not None # Type narrowing for mypy + # state_dicts = _load_state_dicts(model_path, dev_attrs) + state_dicts = self.load_device_weights( + model_path, + device_id, + [ + "model.embed_tokens.weight", + f"layer_{self.model_args.n_layers}_lm_head.weight_dev_{device_id}", + f"layer_{self.model_args.n_layers}_model.norm.weight_dev_{device_id}", + ], + ) + + dsa = Dsa(self.model_args, device_id, self.num_devices) + dsa.init_tilert_weights(state_dicts) + params.extend(dsa.get_weights_list()) + caches.extend(dsa.get_cache_vars()) + intermediates.extend( + self.generate_params_with_continuous_storage( + dsa.get_temp_vars( + 1, + self.forward_max_seq_len, + { + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "use_topp": self.use_topp, + }, + ), + device_id, + ) + ) + + # generate_params_with_continuous_storage creates zero-filled views. + # Populate sampling_config with actual values. + sampling_config = intermediates[Idx.SAMPLING_CONFIG] + sampling_config.copy_( + torch.tensor( + [ + self.temperature, + self.top_p, + float(self.top_k), + 1.0 if self.use_topp else 0.0, # 0=top1(default), 1=topp + ], + dtype=torch.float32, + device=device_id, + ) + ) + + # Track base (non-MTP) params/caches count for dual-module init + base_params_count = len(params) + base_caches_count = len(caches) + + # Add MTP-specific params when with_mtp is True + if self.with_mtp: + mtp = MTP(self.model_args, device_id, self.num_devices) + mtp.init_tilert_weights(state_dicts) + params.extend(mtp.get_weights_list()) + caches.extend(mtp.get_cache_vars()) + logger.info(f"Loaded real MTP weights for device {device_id}") + + profile_logs = get_profile_log_tensor(device=device_id, num_max_insts=65536) + result = (intermediates, caches, params, profile_logs) + self.multi_devices_results[device_id] = result + self._base_params_count = base_params_count + self._base_caches_count = base_caches_count + + del state_dicts + torch.cuda.empty_cache() + elapsed_time = time.time() - start_time + minutes = int(elapsed_time // 60) + seconds = int(elapsed_time % 60) + time_str = ( + f"{minutes} minutes {seconds} seconds" if minutes > 0 else f"{seconds} seconds" + ) + logger.info(f"Completed loading weights for device {device_id} in {time_str}") + + threads = [] + exceptions: list[Exception | None] = [None] * self.num_devices + for device_id in range(self.num_devices): + + def _runner(dev_id: int) -> None: + try: + __load_weights(dev_id, model_path) + except Exception as exc: # pragma: no cover - surfaced after join + exceptions[dev_id] = exc + + thread = threading.Thread(target=_runner, args=(device_id,)) + threads.append(thread) + thread.start() + for thread in threads: + thread.join() + for device_id, exc in enumerate(exceptions): + if exc is not None: + raise RuntimeError(f"Failed to initialize device {device_id}: {exc}") from exc + + # Prepare money for all devices + for device_id in range(self.num_devices): + with torch.cuda.device(device_id): + intermediates, caches, params, profile_logs = self._get_device_result(device_id) + # Always prepare the primary module (MTP if with_mtp, else non-MTP) + dsa_show_hands_prepare_money( + params, + intermediates, + caches, + profile_logs, + self.forward_max_seq_len, + self.with_mtp, + self.is_glm5, + ) + # When MTP-capable, also prepare the non-MTP module using base params/caches + if self.with_mtp: + dsa_show_hands_prepare_money( + params[: self._base_params_count], + intermediates, + caches[: self._base_caches_count], + profile_logs, + self.forward_max_seq_len, + False, + self.is_glm5, + ) + + def from_pretrained(self, model_path: str) -> None: + """Load the model weights from the given path.""" + if not os.path.exists(model_path): + raise ValueError(f"Model weights directory {model_path} does not exist") + self._init_weights(model_path) + + def init_random_weights(self) -> None: + """Generate random weights.""" + self._init_weights(None) + + def forward( + self, + token_id: torch.Tensor, + with_mtp: bool | None = None, + ) -> list[DeviceResult]: + active_mtp = with_mtp if with_mtp is not None else self.with_mtp + dsa_show_hands(token_id.cpu(), active_mtp, self.is_glm5) + return [self._get_device_result(device_id) for device_id in range(self.num_devices)] + + def set_sampling_seed(self, seed: int, with_mtp: bool | None = None) -> None: + """Set the sampling seed for top-p sampling. + + The seed is fixed for the entire request. Position provides per-step variation. + + Args: + seed: The sampling seed value. + with_mtp: Override MTP mode for this call. Defaults to self.with_mtp. + """ + active_mtp = with_mtp if with_mtp is not None else self.with_mtp + dsa_show_hands_set_sampling_seed(seed, active_mtp, self.is_glm5) + + def reset_sequence(self) -> None: + if self.with_mtp: + # Reset both MTP and non-MTP modules for clean state + dsa_show_hands_reset(True, self.is_glm5) + dsa_show_hands_reset(False, self.is_glm5) + else: + dsa_show_hands_reset(False, self.is_glm5) + + def cleanup(self) -> None: + if self.with_mtp: + # Cleanup both MTP and non-MTP modules + dsa_show_hands_go_home(True, self.is_glm5) + dsa_show_hands_go_home(False, self.is_glm5) + else: + dsa_show_hands_go_home(False, self.is_glm5) + + def __del__(self) -> None: + try: + self.cleanup() + except Exception as e: + print(f"Exception during cleanup: {e}", file=sys.stderr) + + def _get_device_result(self, device_id: int) -> DeviceResult: + device_result = self.multi_devices_results[device_id] + if device_result is None: + raise RuntimeError(f"Device {device_id} is not initialized") + return device_result + + def set_prefill_valid_tokens(self, num_valid_tokens: int) -> None: + """Set the number of valid tokens for prefill mode. + + This controls how many tokens are copied from draft_tokens to predicted_tokens + during prefill. Should be called before forward() when the chunk has padding. + + Args: + num_valid_tokens: Number of valid tokens in the chunk (1-4). + """ + dsa_mtp_e2e_show_hands_set_prefill_valid_tokens(num_valid_tokens, self.is_glm5) + + def set_prefill_mtp_extra_token(self, token: int) -> None: + """Set the extra token for MTP[0] shifted input during prefill. + + Args: + token: The prompt token at (cur_pos + mtp_seq_len). + """ + dsa_mtp_e2e_show_hands_set_prefill_mtp_extra_token(token, self.is_glm5) + + def get_next_draft_tokens(self, device_id: int = 0) -> torch.Tensor: + """Get next_draft_tokens from the specified device. + + Args: + device_id: Device ID to get results from. + + Returns: + next_draft_tokens tensor of shape [1, MTP_SEQ_LEN]. + """ + intermediates, _, _, _ = self._get_device_result(device_id) + return intermediates[Idx.NEXT_DRAFT_TOKENS] + + def get_num_accepted(self, device_id: int = 0) -> int: + """Get number of accepted tokens from the specified device. + + Args: + device_id: Device ID to get results from. + + Returns: + Number of accepted tokens. + """ + intermediates, _, _, _ = self._get_device_result(device_id) + return int(intermediates[Idx.ACCEPTED_TOKENS][0].item()) + + def get_predicted_tokens(self, device_id: int = 0) -> torch.Tensor: + """Get predicted_tokens from the specified device. + + Args: + device_id: Device ID to get results from. + + Returns: + predicted_tokens tensor containing main model predictions. + """ + intermediates, _, _, _ = self._get_device_result(device_id) + return intermediates[Idx.PREDICTED_TOKENS] diff --git a/python/models/deepseek_v3_2/modules/mla.py b/python/models/deepseek_v3_2/modules/mla.py new file mode 100644 index 0000000..6f1d138 --- /dev/null +++ b/python/models/deepseek_v3_2/modules/mla.py @@ -0,0 +1,107 @@ +import torch + +from tilert.models.base import SerializableTileRTModule +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.deepseek_v3_2.ops.layernorm_rope_rotate import LayerNormRoPERotate +from tilert.models.deepseek_v3_2.ops.projo_wkvb import ProjoWKVb +from tilert.models.deepseek_v3_2.ops.projq_wqb import ProjqWqb +from tilert.models.deepseek_v3_2.ops.projx_wis import ProjxWis +from tilert.models.deepseek_v3_2.ops.rmsnorm_kv import KVRMSNorm +from tilert.models.deepseek_v3_2.ops.rmsnorm_projq_wqib import ( + RmsnormProjqWqib, + RmsnormProjqWqibAlgorithm, +) +from tilert.models.deepseek_v3_2.ops.rmsnorm_projx_wqkvia import ( + RMSNormProjxWqkvia, + RMSNormProjxWqkviaAlgorithm, +) +from tilert.models.deepseek_v3_2.ops.unproj_o_allreduce import ( + UnProjOAllReduce, + UnProjOAllReduceAlgorithm, +) + + +class Mla(SerializableTileRTModule): + """Implement the MLA operations.""" + + def __init__(self, model_args: ModelArgs, device_id: int, num_devices: int): + super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices) + + self.rmsnorm_projx_wqkvia = RMSNormProjxWqkvia( + model_args=model_args, device_id=device_id, num_devices=num_devices + ) + if model_args.arch_name == "glm_5": + self.rmsnorm_projx_wqkvia.algorithm = RMSNormProjxWqkviaAlgorithm.DECOUPLED + else: + self.rmsnorm_projx_wqkvia.algorithm = RMSNormProjxWqkviaAlgorithm.GENERAL + self.register_op(self.rmsnorm_projx_wqkvia) + + self.layernorm_rope_rotate = LayerNormRoPERotate( + model_args=model_args, device_id=device_id, num_devices=num_devices + ) + self.register_op(self.layernorm_rope_rotate) + + self.rmsnorm_projq_wqib = RmsnormProjqWqib( + model_args=model_args, device_id=device_id, num_devices=num_devices + ) + if model_args.arch_name == "glm_5": + self.rmsnorm_projq_wqib.algorithm = RmsnormProjqWqibAlgorithm.FP16MMA + else: + self.rmsnorm_projq_wqib.algorithm = RmsnormProjqWqibAlgorithm.BF16 + self.register_op(self.rmsnorm_projq_wqib) + + self.projx_wis = ProjxWis( + model_args=model_args, device_id=device_id, num_devices=num_devices + ) + self.register_op(self.projx_wis) + + self.projq_wqb = ProjqWqb( + model_args=model_args, device_id=device_id, num_devices=num_devices + ) + self.register_op(self.projq_wqb) + + self.rmsnorm_kv = KVRMSNorm( + model_args=model_args, device_id=device_id, num_devices=num_devices + ) + self.register_op(self.rmsnorm_kv) + + self.projo_wkvb = ProjoWKVb( + model_args=model_args, device_id=device_id, num_devices=num_devices + ) + self.register_op(self.projo_wkvb) + + self.unproj_o_allreduce = UnProjOAllReduce( + model_args=model_args, + device_id=device_id, + num_devices=num_devices, + algorithm=UnProjOAllReduceAlgorithm.FP8MMA, + ) + + if model_args.arch_name == "glm_5": + self.unproj_o_allreduce.algorithm = UnProjOAllReduceAlgorithm.FP16MMA + + self.register_op(self.unproj_o_allreduce) + + self.kv_cache: torch.Tensor | None = None + self.pe_cache: torch.Tensor | None = None + self.ki_cache: torch.Tensor | None = None + + def get_cache_vars(self) -> list[torch.Tensor]: + cache_seq_len = self.model_args.max_seq_len + self.model_args.kv_cache_pad + bs_args = (self.model_args.max_batch_size, cache_seq_len) + if self.kv_cache is None: + kv_dim = self.model_args.kv_lora_rank + self.kv_cache = torch.zeros( + *bs_args, kv_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}" + ) + if self.pe_cache is None: + pe_dim = self.model_args.qk_rope_head_dim + self.pe_cache = torch.zeros( + *bs_args, pe_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}" + ) + if self.ki_cache is None: + ki_dim = self.model_args.index_head_dim + self.ki_cache = torch.zeros( + *bs_args, ki_dim, dtype=torch.bfloat16, device=f"cuda:{self.device_id}" + ) + return [*super().get_cache_vars(), self.ki_cache, self.kv_cache, self.pe_cache] diff --git a/python/models/deepseek_v3_2/modules/mlp.py b/python/models/deepseek_v3_2/modules/mlp.py new file mode 100644 index 0000000..1e9a327 --- /dev/null +++ b/python/models/deepseek_v3_2/modules/mlp.py @@ -0,0 +1,47 @@ +from tilert.models.base import SerializableTileRTModule +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.deepseek_v3_2.modules.mla import Mla +from tilert.models.deepseek_v3_2.ops.down_allreduce import DownAllReduce +from tilert.models.deepseek_v3_2.ops.rmsnorm_up_gate_silu import ( + RMSNormUpGateSiLU, + RMSNormUpGateSiLUAlgorithm, +) + + +class Mlp(SerializableTileRTModule): + """Implement the MLP operations.""" + + def __init__(self, model_args: ModelArgs, device_id: int, num_devices: int): + super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices) + + self.rmsnorm_mlp_up_gate_silu = RMSNormUpGateSiLU( + model_args=model_args, + device_id=device_id, + num_devices=num_devices, + ) + if model_args.arch_name == "glm_5": + self.rmsnorm_mlp_up_gate_silu.algorithm = RMSNormUpGateSiLUAlgorithm.FP16MMA + self.register_op(self.rmsnorm_mlp_up_gate_silu) + self.rmsnorm_mlp_down = DownAllReduce( + model_args=model_args, device_id=device_id, num_devices=num_devices + ) + self.register_op(self.rmsnorm_mlp_down) + + +class MlpBlock(SerializableTileRTModule): + """Implement the MOE block operations.""" + + def __init__( + self, model_args: ModelArgs, device_id: int, num_devices: int, remove_selected: bool = False + ): + super().__init__( + model_args=model_args, + device_id=device_id, + num_devices=num_devices, + remove_selected=remove_selected, + ) + + self.mla = Mla(model_args=model_args, device_id=device_id, num_devices=num_devices) + self.register_op(self.mla) + self.mlp = Mlp(model_args=model_args, device_id=device_id, num_devices=num_devices) + self.register_op(self.mlp) diff --git a/python/models/deepseek_v3_2/modules/moe.py b/python/models/deepseek_v3_2/modules/moe.py new file mode 100644 index 0000000..f343e79 --- /dev/null +++ b/python/models/deepseek_v3_2/modules/moe.py @@ -0,0 +1,51 @@ +from tilert.models.base import SerializableTileRTModule +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.deepseek_v3_2.modules.mla import Mla +from tilert.models.deepseek_v3_2.ops.expert_down_allreduce import ExpertDownAllReduce +from tilert.models.deepseek_v3_2.ops.expert_sel_up_gate_silu import ( + ExpertSelectUpGateSiLU, + ExpertSelectUpGateSiLUAlgorithm, +) +from tilert.models.deepseek_v3_2.ops.rmsnorm_expert_proj import RMSNormExpertProj + + +class Moe(SerializableTileRTModule): + """Implement the MOE operations.""" + + def __init__(self, model_args: ModelArgs, device_id: int, num_devices: int): + super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices) + + self.rmsnorm_expert_proj = RMSNormExpertProj( + model_args=model_args, device_id=device_id, num_devices=num_devices + ) + self.register_op(self.rmsnorm_expert_proj) + + self.exp_sel_up_gate_silu = ExpertSelectUpGateSiLU( + model_args=model_args, device_id=device_id, num_devices=num_devices + ) + if model_args.arch_name == "glm_5": + self.exp_sel_up_gate_silu.algorithm = ExpertSelectUpGateSiLUAlgorithm.FP16MMA + self.register_op(self.exp_sel_up_gate_silu) + self.expert_down_allreduce = ExpertDownAllReduce( + model_args=model_args, device_id=device_id, num_devices=num_devices + ) + self.register_op(self.expert_down_allreduce) + + +class MoeBlock(SerializableTileRTModule): + """Implement the MOE block operations.""" + + def __init__( + self, model_args: ModelArgs, device_id: int, num_devices: int, remove_selected: bool = False + ): + super().__init__( + model_args=model_args, + device_id=device_id, + num_devices=num_devices, + remove_selected=remove_selected, + ) + + self.mla = Mla(model_args=model_args, device_id=device_id, num_devices=num_devices) + self.register_op(self.mla) + self.moe = Moe(model_args=model_args, device_id=device_id, num_devices=num_devices) + self.register_op(self.moe) diff --git a/python/models/deepseek_v3_2/modules/mtp.py b/python/models/deepseek_v3_2/modules/mtp.py new file mode 100644 index 0000000..fd43e0e --- /dev/null +++ b/python/models/deepseek_v3_2/modules/mtp.py @@ -0,0 +1,47 @@ +import torch + +from tilert.models.base import SerializableTileRTModule +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.deepseek_v3_2.modules.moe import MoeBlock +from tilert.models.deepseek_v3_2.modules.mtp_preprocess import MTPPreprocessLayer +from tilert.models.deepseek_v3_2.ops import RMSNormHeadProj + + +class MTP(SerializableTileRTModule): + """MTP module.""" + + def __init__(self, model_args: ModelArgs, device_id: int, num_devices: int): + super().__init__(model_args=model_args, device_id=device_id, num_devices=num_devices) + + self.embed_tokens_weight = None + self.freqs_cis = None + + mtp_layer_id = self.model_args.n_layers + self.register_op( + MTPPreprocessLayer(self.model_args, self.num_devices, device_id), + prefix=f"layer_{mtp_layer_id}_", + suffix=f"_dev_{device_id}", + ) + self.register_op( + MoeBlock(model_args=model_args, device_id=device_id, num_devices=num_devices), + prefix=f"layer_{mtp_layer_id}_", + suffix=f"_dev_{device_id}", + ) + self.register_op( + RMSNormHeadProj(model_args=model_args, device_id=device_id, num_devices=num_devices), + prefix=f"layer_{mtp_layer_id}_", + suffix=f"_dev_{device_id}", + retain_weights=True, + ) + + def init_tilert_weights(self, state_dicts: dict[str, torch.Tensor]) -> None: + self.embed_tokens_weight = state_dicts["model.embed_tokens.weight"] + self.freqs_cis = state_dicts["freqs_cis"] + super().init_tilert_weights(state_dicts) + + def get_weights_list(self) -> list[torch.Tensor]: + return [ + self.embed_tokens_weight, + self.freqs_cis, + *super().get_weights_list(), + ] diff --git a/python/models/deepseek_v3_2/modules/mtp_preprocess.py b/python/models/deepseek_v3_2/modules/mtp_preprocess.py new file mode 100644 index 0000000..dc094eb --- /dev/null +++ b/python/models/deepseek_v3_2/modules/mtp_preprocess.py @@ -0,0 +1,244 @@ +"""MTP preprocess layer for DeepSeek v3.""" + +from dataclasses import dataclass + +import torch + +from tilert.models.base import TileRTModule, TilertWeightsConverter +from tilert.models.common import init_func, linear +from tilert.models.deepseek_v3_2.model_args import ModelArgs + +__all__ = [ + "mtp_preprocess_layer", + "MTPPreprocessLayer", + "MTPPreprocessRefWeightsAlias", + "MTPPreprocessTilertWeightsAlias", + "MTPPreprocessWeightsConverter", +] + + +def mtp_preprocess_layer( + params: list[torch.Tensor], + temp_vars: list[torch.Tensor], + profile_logs: torch.Tensor, +) -> torch.Tensor: + """MTP preprocess layer op for DeepSeek v3. + + Output is in temp_vars[28] (eh_proj) for DSA temp vars layout. + """ + return torch.ops.tilert.mtp_preprocess_layer(params, temp_vars, profile_logs) + + +@dataclass +class MTPPreprocessRefWeightsAlias: + """Reference (golden/PyTorch) weight keys for MTP preprocess.""" + + embedding_rmsnorm = "enorm.weight" + hidden_rmsnorm = "hnorm.weight" + eh_proj = "eh_proj.weight" + + @property + def ref_tensor_alias(self) -> list[str]: + return [ + self.embedding_rmsnorm, + self.hidden_rmsnorm, + self.eh_proj, + ] + + def __call__(self) -> list[str]: + return self.ref_tensor_alias + + +@dataclass +class MTPPreprocessTilertWeightsAlias: + """TileRT weight keys for MTP preprocess.""" + + embedding_rmsnorm_gamma = "embedding_rmsnorm_gamma" + hidden_rmsnorm_gamma = "hidden_rmsnorm_gamma" + eh_proj_weights = "eh_proj_weights" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [ + self.embedding_rmsnorm_gamma, + self.hidden_rmsnorm_gamma, + self.eh_proj_weights, + ] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class MTPPreprocessWeightsConverter(TilertWeightsConverter): + """Converts ref-format weights to TileRT format for MTP preprocess.""" + + def convert_to_tilert(self, weights: list[torch.Tensor], device_id: int) -> list[torch.Tensor]: + """ + Convert ref weights to TileRT format for a specific device. + + Args: + weights: [embedding_rmsnorm_gamma, hidden_rmsnorm_gamma, eh_proj.weight] + Ref format: enorm.weight [7168], hnorm.weight [7168], + eh_proj.weight [7168, 14336]. + device_id: Target device ID for weight placement. + + Returns: + MTPPreprocessParams with converted weights for device_id. + """ + device = torch.device(f"cuda:{device_id}") + embedding_rmsnorm_gamma, hidden_rmsnorm_gamma, eh_proj_weight = weights + + embedding_rmsnorm_gamma = embedding_rmsnorm_gamma.to(device=device, dtype=torch.float32) + hidden_rmsnorm_gamma = hidden_rmsnorm_gamma.to(device=device, dtype=torch.float32) + # eh_proj: [out, in] = [7168, 14336]; split on dim=1 -> 8 x [7168, 1792] + # Reshape: [7168, 1792] -> [128, 56, 7, 256] -> transpose(1,2) -> [128, 7, 56, 256] + # eh_proj_weight_splited = torch.chunk(eh_proj_weight, self.num_devices, dim=1) + eh_proj_weights = ( + eh_proj_weight.reshape( + 128, self.model_args.dim // 128, self.model_args.dim * 2 // 256 // 8, 256 + ) + .transpose(1, 2) + .contiguous() + .to(device=device, dtype=torch.bfloat16) + ) + return [embedding_rmsnorm_gamma, hidden_rmsnorm_gamma, eh_proj_weights] + + +class MTPPreprocessLayer(TileRTModule): + """MTP preprocess layer: RMSNorm(embedding), RMSNorm(hidden), concat & project.""" + + def __init__( + self, + model_args: ModelArgs, + num_devices: int, + device_id: int, + ref_weights_alias: MTPPreprocessRefWeightsAlias | None = None, + ) -> None: + super().__init__( + self.__class__.__name__, + model_args=model_args, + num_devices=num_devices, + device_id=device_id, + ) + self.tilert_weights_alias = MTPPreprocessTilertWeightsAlias() + self.ref_weights_alias = ( + ref_weights_alias if ref_weights_alias is not None else MTPPreprocessRefWeightsAlias() + ) + self.hidden_size = model_args.dim + + self.tilert_embedding_rmsnorm_gamma: torch.Tensor | None = None + self.tilert_hidden_rmsnorm_gamma: torch.Tensor | None = None + self.tilert_eh_proj_weights: torch.Tensor | None = None + + self.ref_embedding_rmsnorm_gamma: torch.Tensor | None = None + self.ref_hidden_rmsnorm_gamma: torch.Tensor | None = None + self.ref_eh_proj_weight: torch.Tensor | None = None + + def get_weights_list(self) -> list[torch.Tensor]: + return [ + self.tilert_embedding_rmsnorm_gamma, + self.tilert_hidden_rmsnorm_gamma, + self.tilert_eh_proj_weights, + ] + + def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Repeat ref weights for each device (for init_tilert_weights from ref).""" + embedding_gamma = weights_map[self.ref_weights_alias.embedding_rmsnorm] + hidden_gamma = weights_map[self.ref_weights_alias.hidden_rmsnorm] + eh_proj_weights = weights_map[self.ref_weights_alias.eh_proj] + return { + self.tilert_weights_alias.embedding_rmsnorm_gamma: ( + embedding_gamma[None, ...].repeat(self.num_devices, 1) + ), + self.tilert_weights_alias.hidden_rmsnorm_gamma: ( + hidden_gamma[None, ...].repeat(self.num_devices, 1) + ), + self.tilert_weights_alias.eh_proj_weights: ( + eh_proj_weights[None, ...] + .reshape( + self.model_args.dim, + self.num_devices, + self.model_args.dim * 2 // self.num_devices, + ) + .transpose(0, 1) + ), + } + + def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """Load ref-format weights (enorm.weight, hnorm.weight, eh_proj.weight).""" + self.ref_embedding_rmsnorm_gamma = state_dict[self.ref_weights_alias.embedding_rmsnorm] + self.ref_hidden_rmsnorm_gamma = state_dict[self.ref_weights_alias.hidden_rmsnorm] + self.ref_eh_proj_weight = state_dict[self.ref_weights_alias.eh_proj] + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """ + Load TileRT weights from state_dict. + + state_dict may use: + - Full keys: layer_{layer_id}_{alias}_dev_{device_id} + - Short keys: embedding_rmsnorm_gamma, hidden_rmsnorm_gamma, eh_proj_weights + - Ref keys: enorm.weight, hnorm.weight, eh_proj.weight (then convert) + """ + converter = MTPPreprocessWeightsConverter(self.model_args, self.num_devices) + params = converter.convert_to_tilert( + [state_dict[k] for k in self.tilert_weights_alias()], self.device_id + ) + self.tilert_embedding_rmsnorm_gamma = params[0] + self.tilert_hidden_rmsnorm_gamma = params[1] + self.tilert_eh_proj_weights = params[2] + + def init_random_weights(self) -> dict[str, torch.Tensor]: + """Initialize random ref weights and convert to TileRT for this device.""" + embedding_gamma = init_func(torch.randn(self.hidden_size, dtype=torch.float32)) + hidden_gamma = init_func(torch.randn(self.hidden_size, dtype=torch.float32)) + eh_proj_weights = init_func( + torch.randn(self.hidden_size, self.hidden_size * 2, dtype=torch.bfloat16) + ) + return { + self.ref_weights_alias.embedding_rmsnorm: embedding_gamma, + self.ref_weights_alias.hidden_rmsnorm: hidden_gamma, + self.ref_weights_alias.eh_proj: eh_proj_weights, + } + + def golden_forward( + self, + x: torch.Tensor, + last_hidden_states: torch.Tensor, + ) -> torch.Tensor: + """ + Reference forward: enorm(x), hnorm(last_hidden), concat & eh_proj. + + Args: + x: [batch, seq_len, hidden_size] embedded tokens + last_hidden_states: [batch, seq_len, hidden_size] previous hidden + + Returns: + [batch, seq_len, hidden_size] projected hidden + """ + assert self.ref_embedding_rmsnorm_gamma is not None + assert self.ref_hidden_rmsnorm_gamma is not None + assert self.ref_eh_proj_weight is not None + + future_norm = torch.nn.functional.rms_norm( + x.float(), + [x.size(-1)], + self.ref_embedding_rmsnorm_gamma, + 1e-6, + ) + prev_norm = torch.nn.functional.rms_norm( + last_hidden_states.float(), + [last_hidden_states.size(-1)], + self.ref_hidden_rmsnorm_gamma, + 1e-6, + ) + combined = torch.cat([future_norm, prev_norm], dim=-1) + return linear(combined, self.ref_eh_proj_weight) + + def tilert_forward( + self, + params: list[torch.Tensor], + temp_vars: list[torch.Tensor], + profile_logs: torch.Tensor, + ) -> torch.Tensor: + """Run TileRT mtp_preprocess_layer op.""" + return mtp_preprocess_layer(params, temp_vars, profile_logs) diff --git a/python/models/deepseek_v3_2/ops/__init__.py b/python/models/deepseek_v3_2/ops/__init__.py new file mode 100644 index 0000000..e832905 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/__init__.py @@ -0,0 +1,109 @@ +"""Core operations for deepseek v3.2.""" + +from tilert.models.deepseek_v3_2.ops.down_allreduce import ( + DownAllReduce, + down_allreduce, + down_allreduce_glm5, +) +from tilert.models.deepseek_v3_2.ops.eh_proj_allreduce import EHProjAllReduce, eh_proj_allreduce +from tilert.models.deepseek_v3_2.ops.expert_down_allreduce import ( + ExpertDownAllReduce, + expert_down_allreduce, +) +from tilert.models.deepseek_v3_2.ops.expert_sel_up_gate_silu import ( + ExpertSelectUpGateSiLU, + ExpertSelectUpGateSiLUAlgorithm, +) +from tilert.models.deepseek_v3_2.ops.expert_select import expert_select +from tilert.models.deepseek_v3_2.ops.flash_sparse_mla import flash_sparse_mla +from tilert.models.deepseek_v3_2.ops.layernorm_rope_rotate import layernorm_rope_rotate +from tilert.models.deepseek_v3_2.ops.projo_wkvb import projo_wkvb +from tilert.models.deepseek_v3_2.ops.projq_wqb import projq_wqb +from tilert.models.deepseek_v3_2.ops.projx_wis import projx_wis +from tilert.models.deepseek_v3_2.ops.qkv_rope import ( + QKVRoPE, + QKVRoPERefWeightsAlias, + QKVRoPETilertWeightsAlias, + qkv_rope, +) +from tilert.models.deepseek_v3_2.ops.rmsnorm_expert_proj import RMSNormExpertProj +from tilert.models.deepseek_v3_2.ops.rmsnorm_head_proj import RMSNormHeadProj +from tilert.models.deepseek_v3_2.ops.rmsnorm_kv import rmsnorm_kv +from tilert.models.deepseek_v3_2.ops.rmsnorm_projq_wqib import ( + RmsnormProjqWqib, + RmsnormProjqWqibAlgorithm, + RmsnormProjqWqibWeightsConverter, +) +from tilert.models.deepseek_v3_2.ops.rmsnorm_projx_wqkvia import ( + RMSNormProjxWqkvia, + RMSNormProjxWqkviaAlgorithm, + projx_wqkvia, + rmsnorm_projx_wqkvia, +) +from tilert.models.deepseek_v3_2.ops.rmsnorm_quant import rmsnorm_quant +from tilert.models.deepseek_v3_2.ops.rmsnorm_up_gate_silu import ( + RMSNormUpGateSiLU, + RMSNormUpGateSiLUAlgorithm, +) +from tilert.models.deepseek_v3_2.ops.rotate import ( + Rotate, + RotateRefWeightsAlias, + RotateTilertWeightsAlias, + rotate, + rotate_activation, +) +from tilert.models.deepseek_v3_2.ops.sparse_index import sparse_index, sparse_index_topk +from tilert.models.deepseek_v3_2.ops.topk import TopK, topk_accurate, topk_approximate +from tilert.models.deepseek_v3_2.ops.unproj_o_allreduce import ( + UnProjOAllReduce, + UnProjOAllReduceAlgorithm, + unproj_o_allreduce, +) +from tilert.models.deepseek_v3_2.ops.up_gate_silu import up_gate_silu + +__all__ = [ + "down_allreduce", + "down_allreduce_glm5", + "DownAllReduce", + "expert_down_allreduce", + "ExpertDownAllReduce", + "expert_select", + "up_gate_silu", + "rmsnorm_projx_wqkvia", + "projx_wqkvia", + "rmsnorm_kv", + "unproj_o_allreduce", + "projo_wkvb", + "projq_wqb", + "rotate", + "rotate_activation", + "Rotate", + "RotateRefWeightsAlias", + "RotateTilertWeightsAlias", + "layernorm_rope_rotate", + "TopK", + "topk_approximate", + "topk_accurate", + "sparse_index", + "sparse_index_topk", + "flash_sparse_mla", + "projx_wis", + "qkv_rope", + "QKVRoPE", + "QKVRoPERefWeightsAlias", + "QKVRoPETilertWeightsAlias", + "eh_proj_allreduce", + "rmsnorm_quant", + "RmsnormProjqWqib", + "RmsnormProjqWqibAlgorithm", + "RmsnormProjqWqibWeightsConverter", + "RMSNormExpertProj", + "RMSNormProjxWqkvia", + "RMSNormProjxWqkviaAlgorithm", + "RMSNormUpGateSiLU", + "UnProjOAllReduce", + "UnProjOAllReduceAlgorithm", + "RMSNormHeadProj", + "ExpertSelectUpGateSiLU", + "ExpertSelectUpGateSiLUAlgorithm", +] diff --git a/python/models/deepseek_v3_2/ops/down_allreduce.py b/python/models/deepseek_v3_2/ops/down_allreduce.py new file mode 100644 index 0000000..dfb5a81 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/down_allreduce.py @@ -0,0 +1,384 @@ +"""DownAllreduce operation module.""" + +from collections.abc import Callable +from dataclasses import dataclass +from enum import Enum + +# import torch.nn.functional as F +import torch + +from tilert.models.base import TileRTModule +from tilert.models.common import weight_dequant +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.deepseek_v3_2.ops.expert_down_allreduce import ( + ExpertDownAllReduceWeightsConverter, +) +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "down_allreduce", + "down_allreduce_glm5", + "DownAllReduceAlgorithm", + "DownAllReduce", + "DownAllReduceTilertWeightsAlias", +] + + +def down_allreduce( + vec_in: torch.Tensor, + mat_in: torch.Tensor, + mat_scale: torch.Tensor, + x_in: torch.Tensor, + flag: int, + vec_out: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Fused operation of down and allreduce. + + Args: + vec_in: Input tensor. + mat_in: Input tensor. + mat_scale: Input tensor. + x_in: Input tensor. + flag: Input flag. + vec_out: Output tensor. + profile_logs: Profile logs tensor. This is a 1D tensor of shape + (num_sms,) to store the profile logs of the down_allreduce + operation, where num_sms is the number of SMs on the + device. + """ + torch.ops.tilert.down_allreduce_op( + vec_in, + mat_in, + mat_scale, + x_in, + flag, + vec_out, + profile_logs, + ) + + +def down_allreduce_glm5( + vec_in: torch.Tensor, + mat_in: torch.Tensor, + mat_scale: torch.Tensor, + x_in: torch.Tensor, + flag: int, + vec_out: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Fused operation of down and allreduce. + + Args: + vec_in: Input tensor. + mat_in: Input tensor. + mat_scale: Input tensor. + x_in: Input tensor. + flag: Input flag. + vec_out: Output tensor. + profile_logs: Profile logs tensor. This is a 1D tensor of shape + (num_sms,) to store the profile logs of the down_allreduce + operation, where num_sms is the number of SMs on the + device. + """ + torch.ops.tilert.down_allreduce_glm5_op( + vec_in, + mat_in, + mat_scale, + x_in, + flag, + vec_out, + profile_logs, + ) + + +class DownAllReduceAlgorithm(Enum): + """DownAllReduce algorithm""" + + GENERAL = "general" + + +DownAllReduceWeightsConverter = ExpertDownAllReduceWeightsConverter + + +@dataclass +class DownAllReduceTilertWeightsAlias: + """TileRT weights alias for DownAllReduce.""" + + down_weights = "down_weights" + down_scales = "down_scales" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [self.down_weights, self.down_scales] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class DownAllReduce(TileRTModule): + """DownAllReduce module""" + + def __init__( + self, + model_args: ModelArgs, + device_id: int, + num_devices: int, + algorithm: DownAllReduceAlgorithm = DownAllReduceAlgorithm.GENERAL, + ): + super().__init__( + self.__class__.__name__, + device_id=device_id, + model_args=model_args, + num_devices=num_devices, + ) + + self.arch_name = self.model_args.arch_name + self.dim = self.model_args.dim + + self.inter_dim = self.model_args.inter_dim + self.moe_inter_dim = self.model_args.moe_inter_dim + self.moe_inter_dim_per_device = self.moe_inter_dim // self.num_devices + self.inter_dim_per_device = self.inter_dim // self.num_devices + # effective number of experts + self.n_experts: int = self.inter_dim_per_device // self.moe_inter_dim_per_device + self.block_size = self.model_args.block_size + self.dim_scale_dim = self.dim // self.block_size + self.in_scale_dim = self.inter_dim // self.block_size + self.moe_inter_scale_dim_per_device = self.moe_inter_dim_per_device // self.block_size + self.algorithm = algorithm + + # reference weights + self.ref_down: torch.Tensor | None = None + + # tilert weights + self.tilert_weights: torch.Tensor | None = None + self.tilert_scales: torch.Tensor | None = None + + # tilert vars + self.hidden_out: torch.Tensor | None = None + + self.profile_logs: torch.Tensor | None = None + self.is_init = False + + # tilert_funcs + self.down_allreduce_func: Callable | None = None + + if self.arch_name == "deepseek_v3_2": + self.down_allreduce_func = down_allreduce + elif self.arch_name == "glm_5": + self.down_allreduce_func = down_allreduce_glm5 + else: + raise ValueError(f"Unsupported architecture: {self.arch_name}") + + self.tilert_weights_alias = DownAllReduceTilertWeightsAlias() + + # for device sharding, corresponding to the output of device_sharding + # and input of tilert_forward + self.tensor_alias: list[str] = [ + "down_weights", + "down_scales", + ] + + # reference tensor aliases + self.ref_tensor_alias: list[str] = [ + "mlp.down_proj.weight", + "mlp.down_proj.weight_scale_inv", + ] + + @property + def tilert_tensor_alias(self) -> list[str]: + return self.tilert_weights_alias.tilert_tensor_alias + + def get_weights_list(self) -> list[torch.Tensor]: + """ + Get the weights list. + + Returns: + List of weights. + """ + return [self.tilert_weights, self.tilert_scales] + + def device_sharding( + self, + weights_dict: dict[str, torch.Tensor], + key_prefix: str, # e.g. model.layers.{layer_id}.mlp + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Device sharding. + + Args: + weights_dict: Dictionary of weights. + key_prefix: Key prefix. + Returns: + Tuple of weights. + """ + down_proj_weight_key = f"{key_prefix}.down_proj.weight" + down_proj_scale_key = f"{key_prefix}.down_proj.weight_scale_inv" + down_proj_weight = weights_dict[down_proj_weight_key] + down_proj_scale = weights_dict[down_proj_scale_key] + # To match the old convertcode + down_proj_weight = down_proj_weight.reshape( + self.dim, self.n_experts, self.num_devices, self.moe_inter_dim_per_device + ) + down_proj_weight_splited = torch.split(down_proj_weight, 1, dim=2) + + down_proj_weight_splited = [ + down_proj_weight_splited[i] + .reshape(self.dim, self.n_experts, self.moe_inter_dim_per_device) + .transpose(0, 1) + .contiguous() + for i in range(self.num_devices) + ] + + down_proj_scale = down_proj_scale.reshape( + self.dim_scale_dim, + self.n_experts, + self.num_devices, + self.moe_inter_scale_dim_per_device, + ) + down_proj_scale_splited = torch.split(down_proj_scale, 1, dim=2) + down_proj_scale_splited = [ + down_proj_scale_splited[i] + .reshape(self.dim_scale_dim, self.n_experts, self.moe_inter_scale_dim_per_device) + .transpose(0, 1) + .contiguous() + for i in range(self.num_devices) + ] + down_weights = torch.stack(down_proj_weight_splited, dim=0) + down_scales = torch.stack(down_proj_scale_splited, dim=0) + return down_weights.contiguous(), down_scales.contiguous() + + def init_reference_weights( + self, + state_dict: dict[str, torch.Tensor], + key_prefix: str, # e.g. model.layers.{layer_id}.mlp + device_id: int = 0, + ) -> None: + """ + Initialize the reference weights. + + Args: + state_dict: State dictionary. + device_id: Device ID. + """ + sharded_list = self.device_sharding(state_dict, key_prefix) + + down_weights = sharded_list[0][device_id] + down_scales = sharded_list[1][device_id] + + down_list = [ + weight_dequant(down_weight, down_scale) + for down_weight, down_scale in zip(down_weights, down_scales) + ] + self.ref_down = torch.stack(down_list, dim=0) + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """ + Initialize the tilert weights. + + Args: + state_dict: State dictionary. + """ + assert self.algorithm is not None, "Algorithm is not set" + self.tilert_weights, self.tilert_scales = DownAllReduceWeightsConverter( + self.model_args, self.num_devices + ).dispatch(self.algorithm, [state_dict[alias] for alias in self.tensor_alias]) + + def init_tilert_vars(self, batch_size: int, seq_len: int, device_id: int = 0) -> None: + """ + Initialize the tilert variables. + + Args: + batch_size: Batch size. + seq_len: Sequence length. + """ + # tilert vars + self.hidden_out = torch.zeros( + (batch_size, seq_len, self.dim), + dtype=torch.bfloat16, + device=f"cuda:{device_id}", + ) + self.profile_logs = get_profile_log_tensor(device=f"cuda:{device_id}") + self.is_init = True + + def init_random_weights(self, device_id: int = 0) -> None: + """Initialize the random weights.""" + scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16 + down_weights = torch.randn( + self.dim, self.inter_dim, dtype=torch.bfloat16, device=f"cuda:{device_id}" + ).to(torch.float8_e4m3fn) + + inter_dim_scale_dim = self.inter_dim // self.block_size + dim_scale_dim = self.dim // self.block_size + down_scales = torch.randn( + dim_scale_dim, inter_dim_scale_dim, dtype=scale_dtype, device=f"cuda:{device_id}" + ) + tensor_list = [ + down_weights, + down_scales, + ] + state_dict = dict(zip(self.ref_tensor_alias, tensor_list)) + + self.init_reference_weights(state_dict, "mlp", device_id) + sharded_list = self.device_sharding(state_dict, "mlp") + + sharded_state_dict = { + alias: sharded_list[i][device_id] for i, alias in enumerate(self.tensor_alias) + } + self.init_tilert_weights(sharded_state_dict) + + def golden_forward( + self, + vec_in: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass for the down-project module. + + Args: + vec_in: Input vector. + + Returns: + Output tensor. + """ + assert self.ref_down is not None + bsz = vec_in.shape[0] + assert bsz == 1 + seq_len = vec_in.shape[1] + hidden_out_list = [] + for s in range(seq_len): + hidden_out_w2_list = [] + for i in range(self.n_experts): + hidden_out_w2_sel = vec_in[0, s, i].float() @ self.ref_down[i].float().T + hidden_out_w2_list.append(hidden_out_w2_sel) + hidden_out_w2 = torch.stack(hidden_out_w2_list, dim=0).to(torch.bfloat16) + hidden_out_w2 = torch.sum(hidden_out_w2, dim=0) + hidden_out_list.append(hidden_out_w2) + return torch.stack(hidden_out_list, dim=0)[None, ...] + + def tilert_forward( + self, + vec_in: torch.Tensor, + x_in: torch.Tensor, + flag: int, + ) -> torch.Tensor: + assert self.down_allreduce_func is not None + assert self.hidden_out is not None + self.down_allreduce_func( + vec_in, + self.tilert_weights, + self.tilert_scales, + x_in, + flag, + self.hidden_out, + self.profile_logs, + ) + return self.hidden_out + + def __call__( + self, + x_in: torch.Tensor, + ) -> torch.Tensor: + return self.golden_forward(x_in) diff --git a/python/models/deepseek_v3_2/ops/eh_proj_allreduce.py b/python/models/deepseek_v3_2/ops/eh_proj_allreduce.py new file mode 100644 index 0000000..309751a --- /dev/null +++ b/python/models/deepseek_v3_2/ops/eh_proj_allreduce.py @@ -0,0 +1,287 @@ +"""EHProjAllReduce operation module.""" + +from dataclasses import dataclass +from enum import Enum + +# import torch.nn.functional as F +import torch + +from tilert.models.base import TileRTModule, TilertWeightsConverter +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "eh_proj_allreduce", + "EHProjAllReduceTilertWeightsAlias", +] + + +def eh_proj_allreduce( + vec_in_enorm: torch.Tensor, + vec_in_hnorm: torch.Tensor, + w_eh: torch.Tensor, + flag: int, + vec_out: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Fused operation of EHProj and allreduce. + + Args: + vec_in_enorm: Input tensor of shape (1, seq_len, 7168). + vec_in_hnorm: Input tensor of shape (1, seq_len, 7168). + w_eh: Input tensor of shape (7168, 1792) or (128, 7, 56, 256). + flag: Input tensor. + vec_out: Output tensor of shape (1, seq_len, 7168). + profile_logs: Profile logs tensor. This is a 1D tensor of shape + (num_sms,) to store the profile logs of the eh_proj_allreduce + operation, where num_sms is the number of SMs on the + device. + """ + dim = vec_in_enorm.shape[-1] + if dim == 7168: + func_call = torch.ops.tilert.eh_proj_allreduce_op + elif dim == 6144: + func_call = torch.ops.tilert.eh_proj_allreduce_glm5_op + else: + raise ValueError(f"Unsupported dimension: {dim}") + func_call(vec_in_enorm, vec_in_hnorm, w_eh, flag, vec_out, profile_logs) + + +class EHProjAllReduceAlgorithm(Enum): + """EHProjAllReduce algorithm""" + + GENERAL = "general" + + +class EHProjAllReduceWeightsConverter(TilertWeightsConverter): + """EHProj weights converter""" + + def convert_to_general(self, weights_list: list[torch.Tensor]) -> tuple[torch.Tensor]: + """ + Convert the weights to general format. + + Args: + weights_list: List of weights. + + Returns: + Tuple of weights. + """ + args = self.model_args + assert args.arch_name == "deepseek_v3_2" or args.arch_name == "glm_5" + dim = args.dim + num_sms = 128 + dim_per_sm = dim // num_sms + in_dim = dim * 2 + in_dim_per_device = in_dim // self.num_devices + stages = in_dim_per_device // 256 + + with torch.inference_mode(): + (proj_weights,) = weights_list + proj_weights = proj_weights.reshape(num_sms, dim_per_sm, stages, 256) + proj_weights = proj_weights.transpose(1, 2) + return (proj_weights.contiguous(),) + + +@dataclass +class EHProjAllReduceTilertWeightsAlias: + """TileRT weights alias for EHProjAllReduce.""" + + eh_proj_weights = "eh_proj_weights" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [self.eh_proj_weights] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class EHProjAllReduce(TileRTModule): + """EHProjAllReduce module""" + + def __init__( + self, + model_args: ModelArgs, + num_devices: int, + algorithm: EHProjAllReduceAlgorithm = EHProjAllReduceAlgorithm.GENERAL, + ): + super().__init__( + self.__class__.__name__, + model_args=model_args, + num_devices=num_devices, + ) + + self.arch_name = self.model_args.arch_name + self.dim = self.model_args.dim + + self.algorithm = algorithm + + # reference weights + self.ref_proj: torch.Tensor | None = None + + # tilert weights + self.tilert_proj: torch.Tensor | None = None + + # tilert vars + self.hidden_out: torch.Tensor | None = None + + self.profile_logs: torch.Tensor | None = None + self.is_init = False + + self.tilert_weights_alias = EHProjAllReduceTilertWeightsAlias() + + # for device sharding, corresponding to the output of device_sharding + # and input of tilert_forward + self.tensor_alias: list[str] = [ + "eh_proj_weights", + ] + + # reference tensor aliases + self.ref_tensor_alias: list[str] = [ + "eh_proj.weight", + ] + + @property + def tilert_tensor_alias(self) -> list[str]: + return self.tilert_weights_alias.tilert_tensor_alias + + def get_weights_list(self) -> list[torch.Tensor]: + """ + Get the weights list. + + Returns: + List of weights. + """ + return [self.tilert_proj] + + def device_sharding( + self, + weights_dict: dict[str, torch.Tensor], + key_prefix: str | None = None, # e.g. model.layers.{layer_id} + ) -> tuple[torch.Tensor]: + """ + Device sharding. + + Args: + weights_dict: Dictionary of weights. + key_prefix: Key prefix. + Returns: + Tuple of weights. + """ + eh_proj_key = "eh_proj.weight" + if key_prefix is not None: + eh_proj_key = f"{key_prefix}.eh_proj.weight" + + eh_proj_weight = weights_dict[eh_proj_key] + in_dim = eh_proj_weight.shape[1] + out_dim = eh_proj_weight.shape[0] + in_dim_per_device = in_dim // self.num_devices + eh_proj_weight = eh_proj_weight.reshape(out_dim, self.num_devices, in_dim_per_device) + eh_proj_weight = eh_proj_weight.transpose(0, 1) + return (eh_proj_weight.contiguous(),) + + def init_reference_weights( + self, + state_dict: dict[str, torch.Tensor], + key_prefix: str | None = None, + device_id: int = 0, + ) -> None: + """ + Initialize the reference weights. + + Args: + state_dict: State dictionary. + device_id: Device ID. + """ + sharded_list = self.device_sharding(state_dict, key_prefix) + + eh_proj_weight = sharded_list[0][device_id] + + self.ref_proj = eh_proj_weight + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """ + Initialize the tilert weights. + + Args: + state_dict: State dictionary. + """ + assert self.algorithm is not None + (self.tilert_proj,) = EHProjAllReduceWeightsConverter( + self.model_args, self.num_devices + ).dispatch(self.algorithm, [state_dict[alias] for alias in self.tensor_alias]) + + def init_tilert_vars(self, batch_size: int, seq_len: int, device_id: int = 0) -> None: + """ + Initialize the tilert variables. + + Args: + batch_size: Batch size. + seq_len: Sequence length. + """ + # tilert vars + self.hidden_out = torch.zeros( + (batch_size, seq_len, self.dim), + dtype=torch.bfloat16, + device=f"cuda:{device_id}", + ) + self.profile_logs = get_profile_log_tensor(device=f"cuda:{device_id}") + self.is_init = True + + def init_random_weights(self, device_id: int = 0) -> None: + """Initialize the random weights.""" + proj_weights = torch.randn( + self.dim, self.dim * 2, dtype=torch.bfloat16, device=f"cuda:{device_id}" + ) + + tensor_list = [ + proj_weights, + ] + state_dict = dict(zip(self.ref_tensor_alias, tensor_list)) + + self.init_reference_weights(state_dict, None, device_id) + sharded_list = self.device_sharding(state_dict, None) + sharded_state_dict = { + alias: sharded_list[i][device_id] for i, alias in enumerate(self.tensor_alias) + } + self.init_tilert_weights(sharded_state_dict) + + def golden_forward( + self, + vec_in_enorm: torch.Tensor, + vec_in_hnorm: torch.Tensor, + device_id: int = 0, + ) -> torch.Tensor: + """ + Forward pass for the down-project module. + + Args: + vec_in_enorm: Input vector of shape (1, seq_len, 7168). + vec_in_hnorm: Input vector of shape (1, seq_len, 7168). + + Returns: + Output tensor. + """ + assert self.ref_proj is not None + bsz = vec_in_enorm.shape[0] + assert bsz == 1 + + vec_in_concat = torch.cat([vec_in_enorm, vec_in_hnorm], dim=-1) + dim_per_device = (self.dim * 2) // self.num_devices + vec_in_slice = vec_in_concat[ + ..., dim_per_device * device_id : dim_per_device * device_id + dim_per_device + ] + return vec_in_slice @ self.ref_proj.T + + def tilert_forward( + self, + vec_in_enorm: torch.Tensor, + vec_in_hnorm: torch.Tensor, + flag: int, + ) -> torch.Tensor: + assert self.hidden_out is not None + eh_proj_allreduce( + vec_in_enorm, vec_in_hnorm, self.tilert_proj, flag, self.hidden_out, self.profile_logs + ) + return self.hidden_out diff --git a/python/models/deepseek_v3_2/ops/expert_down_allreduce.py b/python/models/deepseek_v3_2/ops/expert_down_allreduce.py new file mode 100644 index 0000000..d49bc77 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/expert_down_allreduce.py @@ -0,0 +1,403 @@ +from collections.abc import Callable +from dataclasses import dataclass +from enum import Enum + +import torch + +from tilert.models.base import TileRTModule, TilertWeightsConverter +from tilert.models.common import weight_dequant +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "expert_down_allreduce", + "expert_down_allreduce_glm5", + "ExpertDownAllReduceAlgorithm", + "ExpertDownAllReduce", + "ExpertDownAllReduceTilertWeightsAlias", +] + + +VALID_SEQ_LENS = (1, 2, 4) + + +def expert_down_allreduce( + vec_in: torch.Tensor, + mat_in: torch.Tensor, + mat_scale: torch.Tensor, + indices: torch.Tensor, + scores: torch.Tensor, + x_in: torch.Tensor, + flag: int, + vec_out: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Fused expert down + allreduce (deepseek_v3_2). + + Args: + vec_in: [1, seq_len, n_experts, 256], bfloat16. + mat_in: [n_experts, 6144, 256], float8_e4m3fn. + mat_scale: [n_experts, 1024, 2], bfloat16. + indices: [1, seq_len, 8], int32. + scores: [1, seq_len, 8], float32. + x_in: [1, seq_len, 6144], bfloat16. + flag: User flag. + vec_out: [1, seq_len, 6144], bfloat16 (output). + profile_logs: 1D tensor (num_sms,) for profile logs. + """ + torch.ops.tilert.expert_down_allreduce_op( + vec_in, mat_in, mat_scale, indices, scores, x_in, flag, vec_out, profile_logs + ) + + +def expert_down_allreduce_glm5( + vec_in: torch.Tensor, + mat_in: torch.Tensor, + mat_scale: torch.Tensor, + indices: torch.Tensor, + scores: torch.Tensor, + x_in: torch.Tensor, + flag: int, + vec_out: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Fused expert down + allreduce (glm_5). + + Args: + vec_in: [1, seq_len, n_experts, 256], bfloat16. + mat_in: [n_experts, 6144, 256], float8_e4m3fn. + mat_scale: [n_experts, 1024, 2], bfloat16. + indices: [1, seq_len, 8], int32. + scores: [1, seq_len, 8], float32. + x_in: [1, seq_len, 6144], bfloat16. + flag: User flag. + vec_out: [1, seq_len, 6144], bfloat16 (output). + profile_logs: 1D tensor (num_sms,) for profile logs. + """ + torch.ops.tilert.expert_down_allreduce_glm5_op( + vec_in, mat_in, mat_scale, indices, scores, x_in, flag, vec_out, profile_logs + ) + + +class ExpertDownAllReduceAlgorithm(Enum): + """ExpertDownAllReduce algorithm.""" + + GENERAL = "general" + + +class ExpertDownAllReduceWeightsConverter(TilertWeightsConverter): + """ExpertDownAllReduce weights converter.""" + + @staticmethod + def _swizzle_qmma_16x32(mat_in: torch.Tensor) -> torch.Tensor: + assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32 + assert mat_in.dtype == torch.float8_e4m3fn + pre_shape = mat_in.shape[:-2] + mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4) + return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2) + + @staticmethod + def _swizzle_qmma_8x32(mat_in: torch.Tensor) -> torch.Tensor: + assert mat_in.shape[-2] == 8 and mat_in.shape[-1] == 32 + pre_shape = mat_in.shape[:-2] + return mat_in.reshape(*pre_shape, 8, 2, 4, 4).transpose(-2, -3).contiguous() + + def convert_to_general( + self, weights_list: list[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + """Convert weights to general (tilert) format.""" + args = self.model_args + assert args.arch_name in ("deepseek_v3_2", "glm_5") + arch_name = args.arch_name + dim = args.dim + num_sms = 128 + dim_per_sm = dim // num_sms + dim_scale_dim = dim // args.block_size + + with torch.inference_mode(): + mat_in, scale_in = weights_list + exp_num = mat_in.shape[0] + mat_in_s = mat_in.reshape(exp_num, num_sms, dim_per_sm, 256) + mat_in_0 = mat_in_s[:, :, :16].reshape(exp_num, num_sms, 16, 8, 32).transpose(2, 3) + mat_in_0 = self._swizzle_qmma_16x32(mat_in_0).reshape(exp_num, 128, -1) + mat_in_1 = mat_in_s[:, :, 16:32].reshape(exp_num, num_sms, 16, 8, 32).transpose(2, 3) + mat_in_1 = self._swizzle_qmma_16x32(mat_in_1).reshape(exp_num, 128, -1) + mat_in_2 = mat_in_s[:, :, 32:48].reshape(exp_num, num_sms, 16, 8, 32).transpose(2, 3) + mat_in_2 = self._swizzle_qmma_16x32(mat_in_2).reshape(exp_num, 128, -1) + mat_in_swizzled = torch.cat([mat_in_0, mat_in_1, mat_in_2], dim=2) + if arch_name == "deepseek_v3_2": + mat_in_3 = mat_in_s[:, :, 48:56].reshape(exp_num, num_sms, 8, 8, 32).transpose(2, 3) + mat_in_3 = self._swizzle_qmma_8x32(mat_in_3).reshape(exp_num, 128, -1) + mat_in_swizzled = torch.cat([mat_in_0, mat_in_1, mat_in_2, mat_in_3], dim=2) + mat_in_swizzled = mat_in_swizzled.reshape(exp_num, dim, 256) + + mat_scale_tilert = ( + scale_in.reshape(exp_num, dim_scale_dim, 1, 2) + .repeat(1, 1, 16, 1) + .reshape(exp_num, num_sms, -1) + ) + padding_zeros = torch.zeros( + (exp_num, num_sms, 16 - mat_scale_tilert.shape[-1]), + dtype=scale_in.dtype, + device=scale_in.device, + ) + mat_scale_tilert = torch.cat([mat_scale_tilert, padding_zeros], dim=2) + mat_scale_tilert = mat_scale_tilert.reshape(exp_num, 1024, 2) + if arch_name == "glm_5": + if mat_scale_tilert.dtype != torch.float32: + print( + "Warning: ExpertDownAllReduceWeightsConverter: " + + f"mat_scale_tilert.dtype: {mat_scale_tilert.dtype} " + + "is not float32, convert to float32." + ) + mat_scale_tilert = mat_scale_tilert.to(torch.float32) + else: # DS v3.2, use bfloat16 for mat_scale_tilert + mat_scale_tilert = mat_scale_tilert.to(torch.bfloat16) + return mat_in_swizzled.contiguous(), mat_scale_tilert.contiguous() + + +@dataclass +class ExpertDownAllReduceTilertWeightsAlias: + """TileRT weights alias for ExpertDownAllReduce.""" + + exp_down_weights = "exp_down_weights" + exp_down_scales = "exp_down_scales" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [self.exp_down_weights, self.exp_down_scales] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class ExpertDownAllReduce(TileRTModule): + """ExpertDownAllReduce module.""" + + def __init__( + self, + model_args: ModelArgs, + device_id: int, + num_devices: int, + algorithm: ExpertDownAllReduceAlgorithm = ExpertDownAllReduceAlgorithm.GENERAL, + ): + super().__init__( + self.__class__.__name__, + model_args=model_args, + device_id=device_id, + num_devices=num_devices, + ) + self.arch_name = self.model_args.arch_name + self.dim = self.model_args.dim + self.n_activated_experts: int = self.model_args.n_activated_experts + self.n_routed_experts: int = self.model_args.n_routed_experts + self.n_shared_experts: int = self.model_args.n_shared_experts + self.moe_inter_dim = self.model_args.moe_inter_dim + self.block_size = self.model_args.block_size + self.algorithm = algorithm + + self.ref_down: torch.Tensor | None = None + self.tilert_weights: torch.Tensor | None = None + self.tilert_scales: torch.Tensor | None = None + self.hidden_out: torch.Tensor | None = None + self.profile_logs: torch.Tensor | None = None + self.is_init = False + self.exp_down_allreduce_func: Callable | None = None + + if self.arch_name == "deepseek_v3_2": + self.exp_down_allreduce_func = expert_down_allreduce + elif self.arch_name == "glm_5": + self.exp_down_allreduce_func = expert_down_allreduce_glm5 + else: + raise ValueError(f"Unsupported architecture: {self.arch_name}") + + self.tilert_weights_alias = ExpertDownAllReduceTilertWeightsAlias() + self.tensor_alias = ["exp_down_weights", "exp_down_scales"] + self.ref_tensor_alias = ( + ["mlp.shared_experts.down_proj.weight"] + + [f"mlp.experts.{i}.down_proj.weight" for i in range(self.n_routed_experts)] + + ["mlp.shared_experts.down_proj.weight_scale_inv"] + + [f"mlp.experts.{i}.down_proj.weight_scale_inv" for i in range(self.n_routed_experts)] + ) + + @property + def tilert_tensor_alias(self) -> list[str]: + return self.tilert_weights_alias.tilert_tensor_alias + + def get_weights_list(self) -> list[torch.Tensor]: + return [self.tilert_weights, self.tilert_scales] + + @staticmethod + def process_down_weights( + key_prefix: str, + weights_hf: dict[str, torch.Tensor], + num_devices: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + down_proj_weight_key = f"{key_prefix}.down_proj.weight" + down_proj_scale_key = f"{key_prefix}.down_proj.weight_scale_inv" + down_proj_weight = weights_hf[down_proj_weight_key] + down_proj_scale = weights_hf[down_proj_scale_key] + + dim = down_proj_weight.shape[-2] + dim_scale_dim = down_proj_scale.shape[-2] + moe_inter_dim = down_proj_weight.shape[-1] + in_scale_dim = down_proj_scale.shape[-1] + moe_inter_dim_per_device = moe_inter_dim // num_devices + in_scale_dim_per_device = in_scale_dim // num_devices + + down_proj_weight = down_proj_weight.reshape(dim, num_devices, moe_inter_dim_per_device) + down_proj_weight = down_proj_weight.transpose(0, 1).reshape( + num_devices, 1, dim, moe_inter_dim_per_device + ) + down_proj_scale = down_proj_scale.reshape( + dim_scale_dim, num_devices, in_scale_dim_per_device + ) + down_proj_scale = down_proj_scale.transpose(0, 1).reshape( + num_devices, 1, dim_scale_dim, in_scale_dim_per_device + ) + return down_proj_weight, down_proj_scale + + def device_sharding( + self, + weights_dict: dict[str, torch.Tensor], + key_prefix: str, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert self.n_shared_experts == 1, "Only one shared expert is supported" + down_weights_list = [] + down_scales_list = [] + exp_prefix = f"{key_prefix}.shared_experts" + down_weights, down_scales = self.process_down_weights( + exp_prefix, weights_dict, self.num_devices + ) + down_weights_list.append(down_weights) + down_scales_list.append(down_scales) + for exp_id in range(self.n_routed_experts): + exp_prefix = f"{key_prefix}.experts.{exp_id}" + down_weights, down_scales = self.process_down_weights( + exp_prefix, weights_dict, self.num_devices + ) + down_weights_list.append(down_weights) + down_scales_list.append(down_scales) + down_weights = torch.cat(down_weights_list, dim=1) + down_scales = torch.cat(down_scales_list, dim=1) + return down_weights.contiguous(), down_scales.contiguous() + + def init_reference_weights( + self, + state_dict: dict[str, torch.Tensor], + key_prefix: str, + device_id: int = 0, + ) -> None: + sharded_list = self.device_sharding(state_dict, key_prefix) + down_weights = sharded_list[0][device_id] + down_scales = sharded_list[1][device_id] + + down_list = [ + weight_dequant(down_weight, down_scale) + for down_weight, down_scale in zip(down_weights, down_scales) + ] + self.ref_down = torch.stack(down_list, dim=0) + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + assert self.algorithm is not None, "Algorithm is not set" + self.tilert_weights, self.tilert_scales = ExpertDownAllReduceWeightsConverter( + self.model_args, self.num_devices + ).dispatch(self.algorithm, [state_dict[alias] for alias in self.tensor_alias]) + + def init_tilert_vars(self, batch_size: int, seq_len: int, device_id: int = 0) -> None: + self.hidden_out = torch.zeros( + (batch_size, seq_len, self.dim), + dtype=torch.bfloat16, + device=f"cuda:{device_id}", + ) + self.profile_logs = get_profile_log_tensor(device=f"cuda:{device_id}") + self.is_init = True + + def init_random_weights(self, device_id: int = 0) -> None: + down_weights = [ + torch.randn( + self.dim, self.moe_inter_dim, dtype=torch.bfloat16, device=f"cuda:{device_id}" + ).to(torch.float8_e4m3fn) + for _ in range(self.n_routed_experts + 1) + ] + dim_scale_dim = self.dim // self.block_size + moe_inter_dim_scale_dim = self.moe_inter_dim // self.block_size + scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16 + down_scales = [ + torch.randn( + dim_scale_dim, + moe_inter_dim_scale_dim, + dtype=scale_dtype, + device=f"cuda:{device_id}", + ) + for _ in range(self.n_routed_experts + 1) + ] + state_dict = dict( + zip( + self.ref_tensor_alias, + [*down_weights, *down_scales], + ) + ) + self.init_reference_weights(state_dict, "mlp", device_id) + sharded_list = self.device_sharding(state_dict, "mlp") + sharded_state_dict = { + alias: sharded_list[i][device_id] for i, alias in enumerate(self.tensor_alias) + } + self.init_tilert_weights(sharded_state_dict) + + def golden_forward( + self, + vec_in: torch.Tensor, + indices: torch.Tensor, + scores: torch.Tensor, + ) -> torch.Tensor: + assert self.ref_down is not None + assert vec_in.dim() == 4 and vec_in.size(0) == 1 + seq_len = vec_in.shape[1] + hidden_out_list = [] + for s in range(seq_len): + hidden_out_w2_list = [] + hidden_out_w2_shared = vec_in[0, s, 0].float() @ self.ref_down[0].float().T + hidden_out_w2_list.append(hidden_out_w2_shared) + ref_down_sel = self.ref_down[1:][indices[0, s]] + for i in range(self.n_activated_experts): + hidden_out_w2_sel = vec_in[0, s, i + 1].float() @ ref_down_sel[i].float().T + hidden_out_w2_list.append(hidden_out_w2_sel * scores[0, s, i]) + hidden_out_w2 = torch.stack(hidden_out_w2_list, dim=0).to(torch.bfloat16) + hidden_out_w2 = torch.sum(hidden_out_w2, dim=0) + hidden_out_list.append(hidden_out_w2) + hidden_out = torch.stack(hidden_out_list, dim=0) + return hidden_out[None, ...] + + def tilert_forward( + self, + vec_in: torch.Tensor, + indices: torch.Tensor, + scores: torch.Tensor, + x_in: torch.Tensor, + flag: int, + ) -> torch.Tensor: + assert self.exp_down_allreduce_func is not None + assert self.hidden_out is not None + self.exp_down_allreduce_func( + vec_in, + self.tilert_weights, + self.tilert_scales, + indices, + scores, + x_in, + flag, + self.hidden_out, + self.profile_logs, + ) + return self.hidden_out + + def __call__( + self, + x_in: torch.Tensor, + indices: torch.Tensor, + scores: torch.Tensor, + ) -> torch.Tensor: + return self.golden_forward(x_in, indices, scores) diff --git a/python/models/deepseek_v3_2/ops/expert_sel_up_gate_silu.py b/python/models/deepseek_v3_2/ops/expert_sel_up_gate_silu.py new file mode 100644 index 0000000..50a0a67 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/expert_sel_up_gate_silu.py @@ -0,0 +1,729 @@ +"""ExpertSelectUpGateSiLU operation module.""" + +from dataclasses import dataclass +from enum import Enum + +import numpy as np + +# from typing import Any +import torch +import torch.nn.functional as F + +from tilert.models.base import TileRTModule, TilertWeightsConverter +from tilert.models.common import weight_dequant +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.profiler.utils import parse_profile_log_tensor +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "ExpertSelectUpGateSiLUAlgorithm", + "ExpertSelectUpGateSiLU", + "ExpertSelectUpGateSiLURefWeightsAlias", + "ExpertSelectUpGateSiLUTilertWeightsAlias", + "expert_select_up_gate_silu", +] + + +def expert_select_up_gate_silu( + hidden_in: torch.Tensor, + scores_in: torch.Tensor, + bias_in: torch.Tensor, + experts_weights_in: torch.Tensor, + hidden_out: torch.Tensor, + expert_probs_out: torch.Tensor, + expert_indices_out: torch.Tensor, + profile_logs: torch.Tensor, + algorithm: str = "fp8mma", +) -> None: + """Expert SelectUpGateSiLU operation.""" + args_list = [ + hidden_in, + scores_in, + bias_in, + experts_weights_in, + hidden_out, + expert_probs_out, + expert_indices_out, + profile_logs, + algorithm, + ] + torch.ops.tilert.expert_select_up_gate_silu_op(*args_list) + + +@dataclass +class ExpertSelectUpGateSiLURefWeightsAlias: + """Reference weights alias for ExpertSelectUpGateSiLU.""" + + key_prefix: str = "mlp" + n_routed_experts: int = 256 + + @property + def ref_tensor_alias(self) -> list[str]: + n = self.n_routed_experts + return ( + [f"{self.key_prefix}.gate.e_score_correction_bias"] + + [f"{self.key_prefix}.shared_experts.gate_proj.weight"] + + [f"{self.key_prefix}.experts.{i}.gate_proj.weight" for i in range(n)] + + [f"{self.key_prefix}.shared_experts.up_proj.weight"] + + [f"{self.key_prefix}.experts.{i}.up_proj.weight" for i in range(n)] + + [f"{self.key_prefix}.shared_experts.gate_proj.weight_scale_inv"] + + [f"{self.key_prefix}.experts.{i}.gate_proj.weight_scale_inv" for i in range(n)] + + [f"{self.key_prefix}.shared_experts.up_proj.weight_scale_inv"] + + [f"{self.key_prefix}.experts.{i}.up_proj.weight_scale_inv" for i in range(n)] + ) + + def __call__(self) -> list[str]: + return self.ref_tensor_alias + + +@dataclass +class ExpertSelectUpGateSiLUTilertWeightsAlias: + """TileRT weights alias for ExpertSelectUpGateSiLU.""" + + exp_bias = "exp_bias" + exp_gate_weights = "exp_gate_weights" + exp_gate_scales = "exp_gate_scales" + exp_up_weights = "exp_up_weights" + exp_up_scales = "exp_up_scales" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [ + self.exp_bias, + self.exp_gate_weights, + self.exp_gate_scales, + self.exp_up_weights, + self.exp_up_scales, + ] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class ExpertSelectUpGateSiLUAlgorithm(Enum): + """ExpertSelectUpGateSiLU algorithm""" + + FP8MMA = "fp8mma" + FP16MMA = "fp16mma" + + +class ExpertSelectUpGateSiLUWeightsConverter(TilertWeightsConverter): + """ExpertSelectUpGateSiLU weights converter""" + + @staticmethod + def _swizzle_qmma_16x32(mat_in: torch.Tensor) -> torch.Tensor: + assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32 + assert mat_in.dtype == torch.float8_e4m3fn + # PTX isa fig.88 + pre_shape = mat_in.shape[:-2] + mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4) + return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2) + + @staticmethod + def _swizzle_mma_16x32(mat_in: torch.Tensor) -> torch.Tensor: + assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32 + # PTX isa fig.88 + pre_shape = mat_in.shape[:-2] + mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4) + return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2) + + @staticmethod + def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor: + assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16 + # PTX isa fig.88 + pre_shape = mat_in.shape[:-2] + mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4) + return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2) + + @staticmethod + def tilert_to_tilert_144sm( + mat_in: torch.Tensor, mat_scale_in: torch.Tensor, mma_type: str | None = None + ) -> torch.Tensor: + """ + Convert tilert weights and scales to tilert_144sm input format. + + Args: + mat_in: tilert weights + mat_scale_in: tilert scales + mma_type: MMA type, None,"16x32" or "16x16" + Returns: + tilert_144sm weights and scales + """ + exp_num = mat_in.shape[0] + assert mat_in.shape == (exp_num, 512, 7168) + assert mat_scale_in.shape == (exp_num, 4, 64) + weights_trt = mat_in.reshape(exp_num, 128, 4, 7168) + weights_w1 = weights_trt[:, :, :2].reshape(exp_num, 256, 7168) + weights_w3 = weights_trt[:, :, 2:].reshape(exp_num, 256, 7168) + # to 16x1024 blocks + weights_w1 = weights_w1.reshape(exp_num, 16, 16, 7, 1024).transpose(2, 3) + weights_w3 = weights_w3.reshape(exp_num, 16, 16, 7, 1024).transpose(2, 3) + if mma_type == "16x32": + weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 32, 32).transpose(3, 4) + weights_w1 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x32(weights_w1) + weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 1024) + weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 32, 32).transpose(3, 4) + weights_w3 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x32(weights_w3) + weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 1024) + elif mma_type == "16x16": + weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 64, 16).transpose(3, 4) + weights_w1 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x16(weights_w1) + weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 1024) + weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 64, 16).transpose(3, 4) + weights_w3 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x16(weights_w3) + weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 1024) + + weights = torch.cat([weights_w1, weights_w3], dim=3) + assert weights.shape == (exp_num, 16, 7, 32, 1024) + weights = weights.reshape(exp_num, 16, 7, 32 * 1024) + + # For scales, first unswizzle + scales_unswizzled = torch.zeros(exp_num, 4, 56) + for i in range(64): + if ((i % 8) * 8 + i // 8) < 56: + scales_unswizzled[..., ((i % 8) * 8 + i // 8)] = mat_scale_in[..., i] + scales_unswizzled = scales_unswizzled.reshape(exp_num, 2, 2, 56) + + scales_w1 = scales_unswizzled[:, :, :1].repeat(1, 1, 8, 1).reshape(exp_num, 16, 1, 7, 8) + scales_w1 = scales_w1.transpose(2, 3) + scales_w3 = scales_unswizzled[:, :, 1:].repeat(1, 1, 8, 1).reshape(exp_num, 16, 1, 7, 8) + scales_w3 = scales_w3.transpose(2, 3) + scales = torch.cat([scales_w1, scales_w3], dim=3) + assert scales.shape == (exp_num, 16, 7, 2, 8) + scales = ( + scales.reshape(exp_num, 16, 7, 2 * 8).to(torch.bfloat16).view(dtype=torch.float8_e4m3fn) + ) + weights_and_scales = torch.zeros( + exp_num, 16, 7, 32 * 1024 + 128, dtype=torch.float8_e4m3fn, device=mat_in.device + ) + weights_and_scales[:, :, :, : 32 * 1024].copy_(weights) + weights_and_scales[:, :, :, 32 * 1024 : 32 * 1024 + 32].copy_(scales) + return weights_and_scales + + @staticmethod + def tilert_to_tilert_144sm_mma( + mat_in: torch.Tensor, mat_scale_in: torch.Tensor, mma_type: str = "16x32" + ) -> torch.Tensor: + """ + Convert tilert weights and scales to tilert_144sm_mma input format. + + Args: + mat_in: tilert weights + mat_scale_in: tilert scales + Returns: + tilert_144sm weights and scales + """ + return ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm( + mat_in, mat_scale_in, mma_type + ) + + def convert_to_mma( + self, weights_list: list[torch.Tensor], algorithm: str = "fp8mma" + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert the weights to mma format. + + Args: + weights: List of weights. + + Returns: + Tuple of weights. + """ + args = self.model_args + dim = args.dim + pages = dim // 1024 # 6 for GLM5, 7 for DS v3.2 + dim_scale_dim = dim // args.block_size + with torch.inference_mode(): + # w1: gate, w3: up + bias_or_gamma, weights_w1, scales_w1, weights_w3, scales_w3 = weights_list + exp_num = weights_w1.shape[0] + # to 16x1024 blocks + weights_w1 = weights_w1.reshape(exp_num, 16, 16, pages, 1024).transpose(2, 3) + weights_w3 = weights_w3.reshape(exp_num, 16, 16, pages, 1024).transpose(2, 3) + # to 16x32 blocks and swizzle + if algorithm == "fp8mma": + weights_w1 = weights_w1.reshape(exp_num, 16, pages, 16, 32, 32).transpose(3, 4) + weights_w1 = self._swizzle_qmma_16x32(weights_w1) + weights_w1 = weights_w1.reshape(exp_num, 16, pages, 16, 1024) + weights_w3 = weights_w3.reshape(exp_num, 16, pages, 16, 32, 32).transpose(3, 4) + weights_w3 = self._swizzle_qmma_16x32(weights_w3) + weights_w3 = weights_w3.reshape(exp_num, 16, pages, 16, 1024) + elif algorithm == "fp16mma": + weights_w1 = weights_w1.reshape(exp_num, 16, pages, 16, 64, 16).transpose(3, 4) + weights_w1 = self._swizzle_mma_16x16(weights_w1) + weights_w1 = weights_w1.reshape(exp_num, 16, pages, 16, 1024) + weights_w3 = weights_w3.reshape(exp_num, 16, pages, 16, 64, 16).transpose(3, 4) + weights_w3 = self._swizzle_mma_16x16(weights_w3) + weights_w3 = weights_w3.reshape(exp_num, 16, pages, 16, 1024) + else: + raise ValueError(f"Unsupported algorithm: {algorithm}") + # concat w1 and w3 + weights: torch.Tensor = torch.cat([weights_w1, weights_w3], dim=3) + assert weights.shape == (exp_num, 16, pages, 32, 1024) + weights = weights.reshape(exp_num, 16, pages, 32 * 1024) + + scales_w1 = ( + scales_w1.reshape(exp_num, 2, 1, dim_scale_dim) + .repeat(1, 1, 8, 1) + .reshape(exp_num, 16, 1, pages, 8) + ) + scales_w1 = scales_w1.transpose(2, 3) + scales_w3 = ( + scales_w3.reshape(exp_num, 2, 1, dim_scale_dim) + .repeat(1, 1, 8, 1) + .reshape(exp_num, 16, 1, pages, 8) + ) + scales_w3 = scales_w3.transpose(2, 3) + scales = torch.cat([scales_w1, scales_w3], dim=3) + assert scales.shape == (exp_num, 16, pages, 2, 8) + + if self.model_args.arch_name == "glm_5": + if scales.dtype != torch.float32: + print( + "Warning: ExpertSelectUpGateSiLUWeightsConverter: " + + f"scales.dtype: {scales.dtype} " + + "is not float32, convert to float32." + ) + scales = scales.to(torch.float32) + else: # DS v3.2, use bfloat16 for scales + scales = scales.to(torch.bfloat16) + + scales = scales.reshape(exp_num, 16, pages, 2 * 8).view(dtype=torch.float8_e4m3fn) + + weights_and_scales = torch.zeros( + exp_num, + 16, + pages, + 32 * 1024 + 128, + dtype=torch.float8_e4m3fn, + device=weights_w1.device, + ) + weights_and_scales[:, :, :, : 32 * 1024].copy_(weights) + weights_and_scales[:, :, :, 32 * 1024 : 32 * 1024 + scales.shape[-1]].copy_(scales) + + return bias_or_gamma.float(), weights_and_scales.contiguous() + + def convert_to_fp8mma( + self, weights_list: list[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert the weights to fp8mma format. + + Args: + weights: List of weights. + + Returns: + Tuple of weights. + """ + return self.convert_to_mma(weights_list, "fp8mma") + + def convert_to_fp16mma( + self, weights_list: list[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert the weights to fp16mma format. + + Args: + weights: List of weights. + + Returns: + Tuple of weights. + """ + return self.convert_to_mma(weights_list, "fp16mma") + + +class ExpertSelectUpGateSiLU(TileRTModule): + """ExpertSelectUpGateSiLU module""" + + def __init__( + self, + model_args: ModelArgs, + num_devices: int, + device_id: int = 0, + ref_weights_alias: ExpertSelectUpGateSiLURefWeightsAlias | None = None, + tilert_weights_alias: ExpertSelectUpGateSiLUTilertWeightsAlias | None = None, + algorithm: ExpertSelectUpGateSiLUAlgorithm = ExpertSelectUpGateSiLUAlgorithm.FP8MMA, + ): + super().__init__( + self.__class__.__name__, + model_args=model_args, + num_devices=num_devices, + device_id=device_id, + ) + + self.arch_name = self.model_args.arch_name + self.dim = self.model_args.dim + + self.n_activated_experts = self.model_args.n_activated_experts + self.n_routed_experts = self.model_args.n_routed_experts + self.n_shared_experts = self.model_args.n_shared_experts + self.moe_inter_dim = self.model_args.moe_inter_dim + self.n_expert_groups = self.model_args.n_expert_groups + self.n_limited_groups = self.model_args.n_limited_groups + self.route_scale = self.model_args.route_scale + self.block_size = self.model_args.block_size + self.algorithm = algorithm + + self.tilert_weights_alias = ( + tilert_weights_alias + if tilert_weights_alias is not None + else ExpertSelectUpGateSiLUTilertWeightsAlias() + ) + self.ref_weights_alias = ( + ref_weights_alias + if ref_weights_alias is not None + else ExpertSelectUpGateSiLURefWeightsAlias( + key_prefix="mlp", n_routed_experts=self.n_routed_experts + ) + ) + + # reference weights + self.ref_bias: torch.Tensor | None = None + self.ref_gate: torch.Tensor | None = None + self.ref_up: torch.Tensor | None = None + + # tilert weights + self.tilert_bias: torch.Tensor | None = None + self.tilert_weights: torch.Tensor | None = None + # for compatibility, to be removed in the future + self.tilert_scales = torch.zeros(1, dtype=torch.bfloat16, device=torch.device("cuda")) + + # tilert vars + self.hidden_out: torch.Tensor | None = None + self.expert_probs: torch.Tensor | None = None + self.expert_indices: torch.Tensor | None = None + + self.profile_logs: torch.Tensor | None = None + self.is_init = False + + self._tensor_alias = self.tilert_weights_alias() + self._tilert_tensor_alias = [ + self.tilert_weights_alias.exp_bias, + "exp_upgate_weights", + "exp_upgate_scales", + ] + + @property + def tensor_alias(self) -> list[str]: + return self._tensor_alias + + @property + def tilert_tensor_alias(self) -> list[str]: + """Output weight names for get_weights_list (backward compat).""" + return self._tilert_tensor_alias + + def get_weights_list(self) -> list[torch.Tensor]: + """ + Get the weights list. + + Returns: + List of weights. + """ + return [self.tilert_bias, self.tilert_weights, self.tilert_scales] + + @staticmethod + def process_gate_up_weights( + key_prefix: str, # e.g. mlp.shared_experts or mlp.experts.{id} + weights_hf: dict[str, torch.Tensor], + num_devices: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + gate_proj_weight_key = f"{key_prefix}.gate_proj.weight" + gate_proj_scale_key = f"{key_prefix}.gate_proj.weight_scale_inv" + up_proj_weight_key = f"{key_prefix}.up_proj.weight" + up_proj_scale_key = f"{key_prefix}.up_proj.weight_scale_inv" + + gate_proj_weight = weights_hf[gate_proj_weight_key] + gate_proj_scale = weights_hf[gate_proj_scale_key] + up_proj_weight = weights_hf[up_proj_weight_key] + up_proj_scale = weights_hf[up_proj_scale_key] + dim = gate_proj_weight.shape[-1] + in_dim = gate_proj_weight.shape[-2] + scale_dim = gate_proj_scale.shape[-1] + in_scale_dim = gate_proj_scale.shape[-2] + in_dim_per_device = in_dim // num_devices + in_scale_dim_per_device = in_scale_dim // num_devices + gate_proj_weight = gate_proj_weight.reshape(num_devices, 1, in_dim_per_device, dim) + gate_proj_scale = gate_proj_scale.reshape( + num_devices, 1, in_scale_dim_per_device, scale_dim + ) + up_proj_weight = up_proj_weight.reshape(num_devices, 1, in_dim_per_device, dim) + up_proj_scale = up_proj_scale.reshape(num_devices, 1, in_scale_dim_per_device, scale_dim) + return gate_proj_weight, gate_proj_scale, up_proj_weight, up_proj_scale + + def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Device sharding: ref state dict -> tilert sharded tensors (num_devices, ...). + + Args: + weights_map: State dict keyed by ref_weights_alias(). + + Returns: + Dict keyed by tilert_weights_alias() with (num_devices, ...) tensors. + """ + ref_alias = self.ref_weights_alias + key_prefix = ref_alias.key_prefix + + bias_key = f"{key_prefix}.gate.e_score_correction_bias" + bias = weights_map[bias_key] + bias = bias[None, :].repeat(self.num_devices, 1) + + gate_weights_list = [] + gate_scales_list = [] + up_weights_list = [] + up_scales_list = [] + assert self.n_shared_experts == 1, "Only one shared expert is supported" + exp_prefix = f"{key_prefix}.shared_experts" + gate_weights, gate_scales, up_weights, up_scales = self.process_gate_up_weights( + exp_prefix, weights_map, self.num_devices + ) + gate_weights_list.append(gate_weights) + gate_scales_list.append(gate_scales) + up_weights_list.append(up_weights) + up_scales_list.append(up_scales) + + for exp_id in range(self.n_routed_experts): + exp_prefix = f"{key_prefix}.experts.{exp_id}" + gate_weights, gate_scales, up_weights, up_scales = self.process_gate_up_weights( + exp_prefix, weights_map, self.num_devices + ) + gate_weights_list.append(gate_weights) + gate_scales_list.append(gate_scales) + up_weights_list.append(up_weights) + up_scales_list.append(up_scales) + + gate_weights = torch.cat(gate_weights_list, dim=1) + gate_scales = torch.cat(gate_scales_list, dim=1) + up_weights = torch.cat(up_weights_list, dim=1) + up_scales = torch.cat(up_scales_list, dim=1) + tilert_alias = self.tilert_weights_alias + return { + tilert_alias.exp_bias: bias, + tilert_alias.exp_gate_weights: gate_weights, + tilert_alias.exp_gate_scales: gate_scales, + tilert_alias.exp_up_weights: up_weights, + tilert_alias.exp_up_scales: up_scales, + } + + def init_reference_weights( + self, + state_dict: dict[str, torch.Tensor], + device_id: int | None = None, + ) -> None: + """ + Initialize the reference weights. + + Args: + state_dict: State dict keyed by ref_weights_alias(). + device_id: Device ID; defaults to self.device_id. + """ + did = self.device_id if device_id is None else device_id + sharded = self.device_sharding(state_dict) + + tilert_alias = self.tilert_weights_alias + bias = sharded[tilert_alias.exp_bias][did] + gate_weights = sharded[tilert_alias.exp_gate_weights][did] + gate_scales = sharded[tilert_alias.exp_gate_scales][did] + up_weights = sharded[tilert_alias.exp_up_weights][did] + up_scales = sharded[tilert_alias.exp_up_scales][did] + + self.ref_bias = bias + ref_gate_list = [ + weight_dequant(gate_weights[i], gate_scales[i]) for i in range(gate_weights.shape[0]) + ] + ref_up_list = [ + weight_dequant(up_weights[i], up_scales[i]) for i in range(up_weights.shape[0]) + ] + self.ref_gate = torch.stack(ref_gate_list, dim=0) + self.ref_up = torch.stack(ref_up_list, dim=0) + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """ + Initialize the tilert weights. + + Args: + state_dict: State dict keyed by tilert_weights_alias() (per-device). + """ + assert self.algorithm is not None, "Algorithm is not set" + weights_list = [state_dict[alias] for alias in self.tilert_weights_alias()] + self.tilert_bias, self.tilert_weights = ExpertSelectUpGateSiLUWeightsConverter( + self.model_args, self.num_devices + ).dispatch(self.algorithm, weights_list) + + def init_tilert_vars(self, batch_size: int, seq_len: int, device: str = "cuda") -> None: + """ + Initialize the tilert variables. + + Args: + batch_size: Batch size. + seq_len: Sequence length. + """ + # tilert vars + self.hidden_out = torch.zeros( + ( + batch_size, + seq_len, + self.n_activated_experts + self.n_shared_experts, + self.moe_inter_dim // self.num_devices, + ), + dtype=torch.bfloat16, + device=device, + ) + self.expert_probs = torch.zeros( + (batch_size, seq_len, self.n_activated_experts), + dtype=torch.float32, + device=device, + ) + self.expert_indices = torch.zeros( + (batch_size, seq_len, self.n_activated_experts), + dtype=torch.int32, + device=device, + ) + + self.profile_logs = get_profile_log_tensor(device=device) + self.is_init = True + + def init_random_weights(self, device: str = "cuda") -> None: + """ + Initialize the random weights. + + Returns: + None + """ + bias = torch.randn(self.n_routed_experts, dtype=torch.float32, device=device) + gate_weights = [ + torch.randn(self.moe_inter_dim, self.dim, dtype=torch.bfloat16, device=device).to( + torch.float8_e4m3fn + ) + for _ in range(self.n_routed_experts + 1) + ] + up_weights = [ + torch.randn(self.moe_inter_dim, self.dim, dtype=torch.bfloat16, device=device).to( + torch.float8_e4m3fn + ) + for _ in range(self.n_routed_experts + 1) + ] + moe_inter_dim_scale_dim = self.moe_inter_dim // self.block_size + dim_scale_dim = self.dim // self.block_size + scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16 + gate_scales = [ + torch.randn(moe_inter_dim_scale_dim, dim_scale_dim, dtype=scale_dtype, device=device) + for _ in range(self.n_routed_experts + 1) + ] + up_scales = [ + torch.randn(moe_inter_dim_scale_dim, dim_scale_dim, dtype=scale_dtype, device=device) + for _ in range(self.n_routed_experts + 1) + ] + tensor_list = [ + bias, + *gate_weights, + *up_weights, + *gate_scales, + *up_scales, + ] + ref_state_dict = dict(zip(self.ref_weights_alias(), tensor_list)) + self.init_reference_weights(ref_state_dict) + sharded = self.device_sharding(ref_state_dict) + per_device_state = {k: v[self.device_id] for k, v in sharded.items()} + self.init_tilert_weights(per_device_state) + + def _ref_expert_select_ds(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + flatten_dim = np.prod(scores.size()[:-1]) + scores = scores.sigmoid() + original_scores = scores + if self.ref_bias is not None: + scores = scores + self.ref_bias + + if self.n_expert_groups > 1: + scores = scores.view(flatten_dim, self.n_expert_groups, -1) + if self.ref_bias is None: + group_scores = scores.amax(dim=-1) + else: + group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) + indices = group_scores.topk(self.n_limited_groups, dim=-1)[1] + mask = scores.new_ones(flatten_dim, self.n_expert_groups, dtype=torch.bool).scatter_( + 1, indices, False + ) + scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1) + indices = torch.topk(scores, self.n_activated_experts, dim=-1)[1] + indices = indices.view(*original_scores.shape[:-1], self.n_activated_experts) + weights = original_scores.gather(-1, indices) + weights /= weights.sum(dim=-1, keepdim=True) + weights *= self.route_scale + return weights, indices + + def _ref_expert_select_glm5(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # flatten_dim = np.prod(scores.size()[:-1]) + scores = scores.sigmoid() + original_scores = scores + if self.ref_bias is not None: + scores = scores + self.ref_bias + indices = torch.topk(scores, self.n_activated_experts, dim=-1)[1] + indices = indices.view(*original_scores.shape[:-1], self.n_activated_experts) + weights = original_scores.gather(-1, indices) + weights /= weights.sum(dim=-1, keepdim=True) + weights *= self.route_scale + return weights, indices + + def golden_forward( + self, + x_in: torch.Tensor, + scores: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert self.ref_gate is not None + assert self.ref_up is not None + bsz = x_in.shape[0] + seq_len = x_in.shape[1] + assert bsz == 1 + if self.arch_name == "deepseek_v3_2": + weights, indices = self._ref_expert_select_ds(scores) + elif self.arch_name == "glm_5": + weights, indices = self._ref_expert_select_glm5(scores) + else: + raise ValueError(f"Unsupported architecture: {self.arch_name}") + hidden_out_list = [] + for s in range(seq_len): + # ref up-gate silu + hidden_out_w1_list = [] + hidden_out_w3_list = [] + hidden_out_w1_shared = x_in[0, s].float() @ self.ref_gate[0].float().T + hidden_out_w3_shared = x_in[0, s].float() @ self.ref_up[0].float().T + hidden_out_w1_list.append(hidden_out_w1_shared) + hidden_out_w3_list.append(hidden_out_w3_shared) + ref_gate_sel = self.ref_gate[1:][indices[0, s]] + ref_up_sel = self.ref_up[1:][indices[0, s]] + for i in range(self.n_activated_experts): + hidden_out_w1_sel = x_in[0, s].float() @ ref_gate_sel[i].float().T + hidden_out_w3_sel = x_in[0, s].float() @ ref_up_sel[i].float().T + hidden_out_w1_list.append(hidden_out_w1_sel) + hidden_out_w3_list.append(hidden_out_w3_sel) + hidden_out_w1 = torch.stack(hidden_out_w1_list, dim=0) + hidden_out_w3 = torch.stack(hidden_out_w3_list, dim=0) + hidden_out = F.silu(hidden_out_w1.float()) * hidden_out_w3.float() + hidden_out = hidden_out.to(torch.bfloat16) + hidden_out_list.append(hidden_out) + hidden_out = torch.stack(hidden_out_list, dim=0) + hidden_out = hidden_out[None, ...] + return hidden_out, weights, indices + + def tilert_forward( + self, + x_in: torch.Tensor, + scores: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert self.algorithm is not None, "Algorithm is not set" + expert_select_up_gate_silu( + x_in, + scores, + self.tilert_bias, + self.tilert_weights, + self.hidden_out, + self.expert_probs, + self.expert_indices, + self.profile_logs, + self.algorithm.value, + ) + if self.flag_enable_profiling_log: + parse_profile_log_tensor( + self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)] + ) + return self.hidden_out, self.expert_probs, self.expert_indices diff --git a/python/models/deepseek_v3_2/ops/expert_select.py b/python/models/deepseek_v3_2/ops/expert_select.py new file mode 100644 index 0000000..6a16d76 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/expert_select.py @@ -0,0 +1,49 @@ +"""ExpertSelect operation module.""" + +import torch + +__all__ = [ + "expert_select", + "expert_select_one_stage", +] + + +def expert_select( + scores_in: torch.Tensor, + bias_in: torch.Tensor, + expert_probs_out: torch.Tensor, + expert_indices_out: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Expert Select operation. + + Original two-stage expert select operation used in DeepSeek V3.2. + """ + torch.ops.tilert.expert_select_op( + scores_in, + bias_in, + expert_probs_out, + expert_indices_out, + profile_logs, + ) + + +def expert_select_one_stage( + scores_in: torch.Tensor, + bias_in: torch.Tensor, + expert_probs_out: torch.Tensor, + expert_indices_out: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """Expert Select operation. + + Modified one-stage expert select operation used in Kimi and GLM. + """ + torch.ops.tilert.expert_select_glm5_op( + scores_in, + bias_in, + expert_probs_out, + expert_indices_out, + profile_logs, + ) diff --git a/python/models/deepseek_v3_2/ops/flash_sparse_mla.py b/python/models/deepseek_v3_2/ops/flash_sparse_mla.py new file mode 100644 index 0000000..deebddc --- /dev/null +++ b/python/models/deepseek_v3_2/ops/flash_sparse_mla.py @@ -0,0 +1,265 @@ +"""Flash Sparse MLA operation module.""" + +import math + +import torch + +from tilert.models.base import TileRTModule +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.profiler.utils import parse_profile_log_tensor +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "flash_sparse_mla", + "FlashSparseMLACombine", +] + + +def flash_sparse_mla( + query: torch.Tensor, + query_pe: torch.Tensor, + key_value: torch.Tensor, + key_pe: torch.Tensor, + indices: torch.Tensor, + cur_pos: torch.Tensor, + output: torch.Tensor, + profile_logs: torch.Tensor, + split_size: int = 64, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Flash Sparse MLA operation for GLM5. + + Args: + query: Query tensor. (bs, seqlen, heads, dim) + query_pe: Query position embedding tensor. (bs, seqlen, heads, pe_dim) + key_value: Key-value tensor. (bs, seqlen_kv, dim) + key_pe: Key position embedding tensor. (bs, seqlen_kv, pe_dim) + indices: Indices tensor. (bs, seqlen, topk) + cur_pos: cur_pos tensor. (1) + output: Output tensor. + profile_logs: Profile logs tensor. + split_size: Number of splits. + """ + batch, seqlen, heads, hidden_dim = query.shape + if split_size != 64: + raise ValueError( + "The current implementation of flash_sparse_mla_op only supports split_size=64" + ) + if batch != 1: + raise ValueError("The current implementation of flash_sparse_mla_op only supports batch=1") + if seqlen > 4: + raise ValueError( + "The current implementation of flash_sparse_mla_op only supports seqlen<=4" + ) + + seqlen_kv = key_value.shape[1] + index_len = indices.shape[-1] + if index_len > seqlen_kv: + raise ValueError("index_len must be less than or equal to seqlen_kv") + + device = query.device + acc_type = torch.float32 + + dim = key_value.shape[-1] + max_num_splits = 32 # topk / split_size = 2048/64 + + lse = torch.empty((batch, seqlen, heads), device=device, dtype=acc_type) + lse_acc = torch.empty((batch, seqlen, heads, max_num_splits), device=device, dtype=acc_type) + output_acc = torch.empty( + batch, seqlen, heads, max_num_splits, dim, device=device, dtype=acc_type + ) + + if heads == 16: + torch.ops.tilert.flash_sparse_mla_op( + query, + query_pe, + key_value, + key_pe, + indices, + cur_pos, + output, + output_acc, + lse, + lse_acc, + profile_logs, + split_size, + ) + elif heads == 8: + torch.ops.tilert.flash_sparse_mla_glm5_op( + query, + query_pe, + key_value, + key_pe, + indices, + cur_pos, + output, + output_acc, + lse, + lse_acc, + profile_logs, + split_size, + ) + else: + raise ValueError(f"Unsupported heads: {heads}") + return lse, lse_acc, output_acc + + +class FlashSparseMLACombine(TileRTModule): + """Flash Sparse MLA combine module; no weights, uses model_args for scale and config.""" + + def __init__( + self, + model_args: ModelArgs, + num_devices: int, + layer_idx: int = 0, + ): + super().__init__( + type(self).__name__, + model_args=model_args, + num_devices=num_devices, + layer_idx=layer_idx, + ) + self.tilert_tensor_alias: list[str] = [] + self.ref_tensor_alias: list[str] = [] + + scale = (model_args.qk_nope_head_dim + model_args.qk_rope_head_dim) ** -0.5 + if model_args.rope_factor is None: + mscale = 1.0 + else: + mscale = 0.1 * math.log(model_args.rope_factor) + 1.0 + self.softmax_scale = scale * mscale * mscale + + self.profile_logs = get_profile_log_tensor() + + def init_reference_weights( + self, state_dict: dict[str, torch.Tensor], device_id: int = 0 + ) -> None: + del state_dict, device_id + self.is_ref_weights_init = True + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + del state_dict + self.is_tilert_weights_init = True + + def init_random_weights(self) -> None: + self.is_ref_weights_init = True + self.is_tilert_weights_init = True + + def init_tilert_vars(self, batch_size: int, seq_len: int) -> None: + del batch_size, seq_len + self.profile_logs = get_profile_log_tensor() + self.is_var_init = True + + def golden_forward( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_cache: torch.Tensor, + pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + cur_pos: torch.Tensor, + ) -> torch.Tensor: + """Flash Sparse MLA golden version. + + Args: + q_nope: Query tensor. (bs, seqlen, heads, dim) + q_pe: Query position embedding tensor. (bs, seqlen, heads, pe_dim) + kv_cache: Key-value tensor. (bs, seqlen_kv, dim) + pe_cache: Key position embedding tensor. (bs, seqlen_kv, pe_dim) + topk_indices: Indices tensor. (bs, seqlen, topk) + cur_pos: cur_pos tensor. (1) + """ + batch_size = q_nope.shape[0] + seqlen = q_nope.shape[1] + seqlen_kv = kv_cache.shape[1] + + start_pos = int(cur_pos.item()) + mask = ( + torch.full((seqlen, seqlen_kv), float("-inf")).triu_(start_pos + 1) + if seqlen > 1 + else None + ) + + scores = ( + torch.einsum("bshc,btc->bsht", q_nope.float(), kv_cache.float()) + + torch.einsum("bshr,btr->bsht", q_pe.float(), pe_cache.float()) + ) * self.softmax_scale + index_mask = torch.full( + (batch_size, seqlen, seqlen_kv), float("-inf"), device=q_nope.device + ).scatter_(-1, topk_indices, 0) + if mask is not None: + index_mask += mask + + scores += index_mask.unsqueeze(2) + scores = scores.softmax(dim=-1, dtype=torch.float32) + return torch.einsum("bsht,btc->bshc", scores.to(torch.bfloat16), kv_cache) + + def tilert_forward( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_cache: torch.Tensor, + pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + cur_pos: torch.Tensor, + ) -> torch.Tensor: + """Flash Sparse MLA tilert version. + + Args: + q_nope: Query tensor. (bs, seqlen, heads, dim) + q_pe: Query position embedding tensor. (bs, seqlen, heads, pe_dim) + kv_cache: Key-value tensor. (bs, seqlen_kv, dim) + pe_cache: Key position embedding tensor. (bs, seqlen_kv, pe_dim) + topk_indices: Indices tensor. (bs, seqlen, topk) + cur_pos: cur_pos tensor. (1) + """ + batch_size, seqlen, heads, dim = q_nope.shape + v_dim = kv_cache.shape[-1] + + topk_indices = topk_indices.to(torch.int32) + topk_indices = topk_indices[..., : kv_cache.shape[1]] + device = q_nope.device + if any(t.device != device for t in (q_pe, kv_cache, pe_cache, topk_indices, cur_pos)): + raise RuntimeError( + "flash_sparse_mla inputs must be on the same device: " + f"q_nope={device}, q_pe={q_pe.device}, kv_cache={kv_cache.device}, " + f"pe_cache={pe_cache.device}, topk_indices={topk_indices.device}, " + f"cur_pos={cur_pos.device}" + ) + if self.profile_logs is not None and self.profile_logs.device != device: + self.profile_logs = get_profile_log_tensor(device_index=device.index, device=device) + output = torch.zeros( + (batch_size, seqlen, heads, v_dim), dtype=torch.bfloat16, device=device + ) + flash_sparse_mla( + q_nope, + q_pe, + kv_cache, + pe_cache, + topk_indices, + cur_pos, + output, + self.profile_logs, + ) + if self.flag_enable_profiling_log: + # TODO: bug fix for this + torch.cuda.synchronize() + parse_profile_log_tensor( + self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)] + ) + return output + + def to_tilert_weights(self) -> None: + raise NotImplementedError("to_tilert_weights not implemented") + + def __call__( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_cache: torch.Tensor, + pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + cur_pos: torch.Tensor, + ) -> torch.Tensor: + if self.flag_enable_tilert: + return self.tilert_forward(q_nope, q_pe, kv_cache, pe_cache, topk_indices, cur_pos) + return self.golden_forward(q_nope, q_pe, kv_cache, pe_cache, topk_indices, cur_pos) diff --git a/python/models/deepseek_v3_2/ops/head_proj.py b/python/models/deepseek_v3_2/ops/head_proj.py new file mode 100644 index 0000000..5ab8b77 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/head_proj.py @@ -0,0 +1,22 @@ +"""HeadProj operation module.""" + +import torch + +__all__ = [ + "head_proj", +] + + +def head_proj( + hidden_in: torch.Tensor, + weight_in: torch.Tensor, + logits_out: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """Head Projection operation.""" + torch.ops.tilert.head_proj_op( + hidden_in, + weight_in, + logits_out, + profile_logs, + ) diff --git a/python/models/deepseek_v3_2/ops/layernorm_rope_rotate.py b/python/models/deepseek_v3_2/ops/layernorm_rope_rotate.py new file mode 100644 index 0000000..b1bd0a7 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/layernorm_rope_rotate.py @@ -0,0 +1,225 @@ +"""Layernorm_rope_rotate operation module.""" + +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + +from tilert.models.base import TileRTModule +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.deepseek_v3_2.ops.rotate import rotate_activation +from tilert.models.utils import apply_rotary_emb +from tilert.profiler.utils import parse_profile_log_tensor +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "layernorm_rope_rotate", + "LayerNormRoPERotate", + "LayerNormRoPERotateRefWeightsAlias", + "LayerNormRoPERotateTilertWeightsAlias", +] + + +def layernorm_rope_rotate( + input_raw: torch.Tensor, + cur_pos: torch.Tensor, + k_cache_raw: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + freqs_cis: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Layernorm_rope_rotate operation. + + Layernorm_rope_rotate the input tensor `input_raw` and stores the result in `k_cache_raw`. + + Args: + input_raw (torch.Tensor): The input tensor. + cur_pos (torch.Tensor): The current position tensor. + k_cache_raw (torch.Tensor): The output tensor where the result will be stored. + weight (torch.Tensor): The weight tensor. + bias (torch.Tensor): The bias tensor. + freqs_cis (torch.Tensor): The frequency tensor. + profile_logs (torch.Tensor): Tensor for storing profiling logs. + + Returns: + None + """ + if input_raw.dtype != torch.bfloat16: + raise ValueError("input must be a bfloat16 tensor.") + if cur_pos.dtype != torch.int32: + raise ValueError("cur_pos must be a int32 tensor.") + if k_cache_raw.dtype != torch.bfloat16: + raise ValueError("k_cache must be a bfloat16 tensor.") + + if weight.dtype != torch.float32: + raise ValueError("weight must be a float32 tensor.") + + if bias.dtype != torch.float32: + raise ValueError("bias must be a float32 tensor.") + + if freqs_cis.dtype != torch.float32: + raise ValueError("freqs_cis must be a float32 tensor.") + + batch, seq, dim = input_raw.shape + if dim != 128: + raise ValueError("dim must be 128, as we precompute scale inner kernel") + if batch != 1: + raise ValueError("batch must be 1 in this version") + + torch.ops.tilert.layernorm_rope_rotate_op( + input_raw, cur_pos, k_cache_raw, weight, bias, freqs_cis, profile_logs + ) + + +@dataclass +class LayerNormRoPERotateRefWeightsAlias: + """Reference weights alias for LayerNormRoPERotate.""" + + k_weight = "self_attn.indexer.k_norm.weight" + k_bias = "self_attn.indexer.k_norm.bias" + + @property + def ref_tensor_alias(self) -> list[str]: + return [self.k_weight, self.k_bias] + + def __call__(self) -> list[str]: + return self.ref_tensor_alias + + +@dataclass +class LayerNormRoPERotateTilertWeightsAlias: + """TileRT weights alias for LayerNormRoPERotate.""" + + k_weight = "k_weights" + k_bias = "k_bias" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [self.k_weight, self.k_bias] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class LayerNormRoPERotate(TileRTModule): + """LayerNormRoPERotate module: LayerNorm + RoPE + rotate on K indexer output.""" + + def __init__( + self, + model_args: ModelArgs, + num_devices: int, + device_id: int, + ref_weights_alias: LayerNormRoPERotateRefWeightsAlias | None = None, + ): + super().__init__( + self.__class__.__name__, + model_args=model_args, + num_devices=num_devices, + device_id=device_id, + ) + + self.tilert_weights_alias = LayerNormRoPERotateTilertWeightsAlias() + self.ref_weights_alias = ( + ref_weights_alias + if ref_weights_alias is not None + else LayerNormRoPERotateRefWeightsAlias() + ) + + self.rope_head_dim = self.model_args.qk_rope_head_dim + self.head_dim = self.model_args.index_head_dim + + self.ref_weight: torch.Tensor | None = None + self.ref_bias: torch.Tensor | None = None + self.tilert_weight: torch.Tensor | None = None + self.tilert_bias: torch.Tensor | None = None + self.output: torch.Tensor | None = None + self.profile_logs: torch.Tensor | None = None + + def get_weights_list(self) -> list[torch.Tensor]: + return [self.tilert_weight, self.tilert_bias] + + def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Device sharding: replicate weight and bias for each device. + + Args: + weights_map: Map from ref weight alias to tensor. + + Returns: + Map from tilert weight alias to (num_devices, ...) tensors. + """ + k_weight = weights_map[self.ref_weights_alias.k_weight][None, ...].repeat( + self.num_devices, 1 + ) + k_bias = weights_map[self.ref_weights_alias.k_bias][None, ...].repeat(self.num_devices, 1) + return { + self.tilert_weights_alias.k_weight: k_weight, + self.tilert_weights_alias.k_bias: k_bias, + } + + def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + self.ref_weight = state_dict[self.ref_weights_alias.k_weight].contiguous().float() + self.ref_bias = state_dict[self.ref_weights_alias.k_bias].contiguous().float() + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + self.tilert_weight = state_dict[self.tilert_weights_alias.k_weight].contiguous().float() + self.tilert_bias = state_dict[self.tilert_weights_alias.k_bias].contiguous().float() + + def init_random_weights(self) -> None: + ref_weight = torch.ones(self.head_dim, dtype=torch.float32) + ref_bias = torch.zeros(self.head_dim, dtype=torch.float32) + ref_state_dict = dict(zip(self.ref_weights_alias(), [ref_weight, ref_bias])) + self.init_reference_weights(ref_state_dict) + self.init_tilert_weights( + {_k: _v[self.device_id] for _k, _v in self.device_sharding(ref_state_dict).items()} + ) + + def init_tilert_vars(self, batch_size: int, seq_len: int) -> None: + self.cur_pos = torch.tensor([0], dtype=torch.int32) + self.output = torch.zeros((batch_size, seq_len, self.head_dim), dtype=torch.bfloat16) + self.profile_logs = get_profile_log_tensor() + self.is_var_init = True + + def golden_forward(self, idx_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + assert self.ref_weight is not None and self.ref_bias is not None + k = F.layer_norm( + idx_k.float(), + (self.head_dim,), + self.ref_weight, + self.ref_bias, + 1e-6, + ).to(idx_k.dtype) + k_pe, k_nope = torch.split( + k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1 + ) + k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2) + k = torch.cat([k_pe, k_nope], dim=-1) + return rotate_activation(k) + + def tilert_forward(self, idx_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + assert self.tilert_weight is not None and self.tilert_bias is not None + assert self.output is not None and self.profile_logs is not None + rope_freqs = ( + torch.view_as_real(freqs_cis).reshape(*freqs_cis.shape[:-1], -1).float().unsqueeze(1) + ) + layernorm_rope_rotate( + idx_k, + self.cur_pos, + self.output, + self.tilert_weight, + self.tilert_bias, + rope_freqs, + self.profile_logs, + ) + if self.flag_enable_profiling_log: + parse_profile_log_tensor( + self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)] + ) + return self.output + + def __call__(self, idx_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + if self.flag_enable_tilert: + return self.tilert_forward(idx_k, freqs_cis) + return self.golden_forward(idx_k, freqs_cis) diff --git a/python/models/deepseek_v3_2/ops/projo_wkvb.py b/python/models/deepseek_v3_2/ops/projo_wkvb.py new file mode 100644 index 0000000..618e5ee --- /dev/null +++ b/python/models/deepseek_v3_2/ops/projo_wkvb.py @@ -0,0 +1,283 @@ +"""UnprojOB operation module.""" + +from dataclasses import dataclass +from enum import Enum + +import torch + +from tilert.models.base import TileRTModule, TilertWeightsConverter +from tilert.models.common import init_func, weight_dequant +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.profiler.utils import parse_profile_log_tensor +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "projo_wkvb", + "ProjoWKVb", + "ProjoWKVbAlgorithm", + "ProjoWKVbWeightsConverter", + "ProjoWKVbRefWeightsAlias", + "ProjoWKVbTilertWeightsAlias", +] + + +def projo_wkvb( + o_in: torch.Tensor, + wkv_b_b: torch.Tensor, + wkv_b_scales: torch.Tensor, + output: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Define the UnprojOB operation. + + Args: + o_in: Input tensor. + wkv_b_b: Weight tensor. + wkv_b_scales: Scale tensor. + output: Output tensor. + profile_logs: Profile logs tensor. + """ + # Choose operation based on v_head_dim (128 for deepseek_v3_2, 256 for glm5) + if output.shape[-1] == 128: + torch.ops.tilert.projo_wkvb_op(o_in, wkv_b_b, wkv_b_scales, output, profile_logs) + elif output.shape[-1] == 256: + torch.ops.tilert.proj_ob_glm5_op(o_in, wkv_b_b, wkv_b_scales, output, profile_logs) + else: + raise ValueError(f"Unsupported v_head_dim: {output.shape[-1]}") + + +class ProjoWKVbAlgorithm(Enum): + """ProjoWKVb algorithm""" + + GENERAL = "general" + + +class ProjoWKVbWeightsConverter(TilertWeightsConverter): + def __init__(self, model_args: ModelArgs, num_devices: int): + super().__init__(model_args, num_devices) + + def convert_to_general(self, weights: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + with torch.inference_mode(): + tilert_wkv_b_weights, tilert_wkv_b_scales = weights + + # Input weights are already in the correct shape from device_sharding: + # wkv_b_weights: (n_local_heads, v_head_dim, kv_lora_rank) + # wkv_b_scales: (n_local_heads, v_head_dim // block_size, kv_lora_rank // block_size) + wkv_b_b = tilert_wkv_b_weights.contiguous() + wkv_b_b_scales = tilert_wkv_b_scales.contiguous() + if self.model_args.arch_name == "glm_5": + if wkv_b_b_scales.dtype != torch.float32: + print( + "Warning: ProjoWKVbWeightsConverter: " + + f"wkv_b_b_scales.dtype: {wkv_b_b_scales.dtype} " + + "is not float32, convert to float32." + ) + wkv_b_b_scales = wkv_b_b_scales.to(torch.float32) + else: # DS v3.2, use bfloat16 for wkv_b_b_scales + wkv_b_b_scales = wkv_b_b_scales.to(torch.bfloat16) + + wkv_b_b = wkv_b_b.detach() + wkv_b_b_scales = wkv_b_b_scales.detach() + + return wkv_b_b, wkv_b_b_scales + + +@dataclass +class ProjoWKVbRefWeightsAlias: + """Reference weights alias for ProjoWKVb.""" + + wkv_b_weights = "self_attn.kv_b_proj.weight" + wkv_b_scales = "self_attn.kv_b_proj.weight_scale_inv" + + @property + def ref_tensor_alias(self) -> list[str]: + return [self.wkv_b_weights, self.wkv_b_scales] + + def __call__(self) -> list[str]: + return self.ref_tensor_alias + + +@dataclass +class ProjoWKVbTilertWeightsAlias: + """TileRT weights alias for ProjoWKVb.""" + + wkv_b_weights = "wkv_b2_weights" + wkv_b_scales = "wkv_b2_scales" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [self.wkv_b_weights, self.wkv_b_scales] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class ProjoWKVb(TileRTModule): + """ProjoWKVb module: O projection (wkv_b) for output.""" + + def __init__( + self, + model_args: ModelArgs, + num_devices: int, + device_id: int = 0, + ref_weights_alias: ProjoWKVbRefWeightsAlias | None = None, + ): + super().__init__( + self.__class__.__name__, + model_args=model_args, + num_devices=num_devices, + device_id=device_id, + ) + + self.tilert_weights_alias = ProjoWKVbTilertWeightsAlias() + self.ref_weights_alias = ( + ref_weights_alias if ref_weights_alias is not None else ProjoWKVbRefWeightsAlias() + ) + + self.ref_wkv_b: torch.Tensor | None = None + self.tilert_wkv_b_b: torch.Tensor | None = None + self.tilert_wkv_b_b_scales: torch.Tensor | None = None + self.output: torch.Tensor | None = None + self.profile_logs: torch.Tensor | None = None + + self.num_local_heads = self.model_args.n_heads // self.num_devices + + # lora dim and quant block size + self.wkvb_lora_rank = self.model_args.kv_lora_rank + self.wkvb_lora_rank_qsize = self.wkvb_lora_rank // self.model_args.block_size + + self.wkvb_head_dim = self.model_args.qk_nope_head_dim + self.model_args.v_head_dim + self.wkvb_v_head_dim = self.model_args.v_head_dim + left_head_dim = self.wkvb_head_dim % self.model_args.block_size + if left_head_dim != 0: + assert self.model_args.block_size % left_head_dim == 0 + self.head_dim_block_size = left_head_dim + self.head_dim_scale_repeat = self.model_args.block_size // self.head_dim_block_size + else: + self.head_dim_scale_repeat = 1 + self.head_dim_block_size = self.model_args.block_size + self.wkvb_head_qsize = self.wkvb_head_dim // self.head_dim_block_size + self.wkvb_v_head_qsize = self.wkvb_v_head_dim // self.head_dim_block_size + + def get_weights_list(self) -> list[torch.Tensor]: + return [self.tilert_wkv_b_b, self.tilert_wkv_b_b_scales] + + def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Device sharding: split weights and scales per device. + + Args: + weights_map: Map from ref weight alias to tensor. + + Returns: + Map from tilert weight alias to (num_devices, ...) tensors. + """ + kv_b_proj_weight = weights_map[self.ref_weights_alias.wkv_b_weights] + kv_b_proj_weight_scale = weights_map[self.ref_weights_alias.wkv_b_scales] + + dev_heads = (self.num_devices, self.num_local_heads) + wkvb = kv_b_proj_weight.view(*dev_heads, self.wkvb_head_dim, self.wkvb_lora_rank)[ + :, :, -self.wkvb_v_head_dim : + ] + wkvb_scales = ( + kv_b_proj_weight_scale.view( + self.num_devices, + self.num_local_heads * self.wkvb_head_dim // self.model_args.block_size, + 1, + self.wkvb_lora_rank_qsize, + ) + .contiguous() + .repeat(1, 1, self.head_dim_scale_repeat, 1) + .view( + self.num_devices, + self.num_local_heads, + self.wkvb_head_qsize, + self.wkvb_lora_rank_qsize, + ) + .contiguous()[:, :, -self.wkvb_v_head_qsize :] + ) + return { + self.tilert_weights_alias.wkv_b_weights: wkvb.contiguous(), + self.tilert_weights_alias.wkv_b_scales: wkvb_scales.contiguous(), + } + + def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + sharding_size = self.num_local_heads * self.wkvb_head_dim + sharding_start = self.device_id * sharding_size + sharding_end = sharding_start + sharding_size + wkv_b = weight_dequant( + state_dict[self.ref_weights_alias.wkv_b_weights], + state_dict[self.ref_weights_alias.wkv_b_scales], + ) + wkv_b = wkv_b[sharding_start:sharding_end, :] + wkv_b = wkv_b.view(self.num_local_heads, self.wkvb_head_dim, self.wkvb_lora_rank) + self.ref_wkv_b = wkv_b[:, -self.wkvb_v_head_dim :] + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + self.tilert_wkv_b_b, self.tilert_wkv_b_b_scales = ProjoWKVbWeightsConverter( + self.model_args, self.num_devices + ).dispatch( + ProjoWKVbAlgorithm.GENERAL, + [ + state_dict[self.tilert_weights_alias.wkv_b_weights], + state_dict[self.tilert_weights_alias.wkv_b_scales], + ], + ) + + def init_random_weights(self) -> None: + wkv_b = init_func( + torch.empty( + self.model_args.n_heads * self.wkvb_head_dim, + self.wkvb_lora_rank, + dtype=torch.float8_e4m3fn, + ) + ) + wkv_b_scales = init_func( + torch.empty( + # Block quant should be applied to the original weight dimension (including head + # dimension) + self.model_args.n_heads * self.wkvb_head_dim // self.model_args.block_size, + self.wkvb_lora_rank_qsize, + dtype=torch.float32, + ) + ) + ref_state_dict = dict( + zip( + self.ref_weights_alias(), + [wkv_b, wkv_b_scales], + ) + ) + self.init_reference_weights(ref_state_dict) + sharded = self.device_sharding(ref_state_dict) + self.init_tilert_weights({k: v[self.device_id] for k, v in sharded.items()}) + + def init_tilert_vars(self, batch_size: int, seq_len: int) -> None: + self.output = torch.zeros( + (batch_size, seq_len, self.num_local_heads, self.wkvb_v_head_dim), + dtype=torch.bfloat16, + ) + self.profile_logs = get_profile_log_tensor() + self.is_var_init = True + + def golden_forward(self, x_out: torch.Tensor) -> torch.Tensor: + assert self.ref_wkv_b is not None + return torch.einsum("bshc,hdc->bshd", x_out, self.ref_wkv_b) + + def tilert_forward(self, x_out: torch.Tensor) -> torch.Tensor: + assert self.tilert_wkv_b_b is not None + assert self.tilert_wkv_b_b_scales is not None + assert self.output is not None + assert self.profile_logs is not None + projo_wkvb( + x_out, + self.tilert_wkv_b_b, + self.tilert_wkv_b_b_scales, + self.output, + self.profile_logs, + ) + if self.flag_enable_profiling_log: + parse_profile_log_tensor( + self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)] + ) + return self.output diff --git a/python/models/deepseek_v3_2/ops/projq_wqb.py b/python/models/deepseek_v3_2/ops/projq_wqb.py new file mode 100644 index 0000000..7287aa2 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/projq_wqb.py @@ -0,0 +1,295 @@ +"""ProjQB operation module.""" + +from dataclasses import dataclass +from enum import Enum + +import torch + +from tilert.models.base import TileRTModule, TilertWeightsConverter +from tilert.models.common import init_func, weight_dequant +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.profiler.utils import parse_profile_log_tensor +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "projq_wqb", + "ProjqWqb", + "ProjqWqbAlgorithm", + "ProjqWqbWeightsConverter", + "ProjqWqbRefWeightsAlias", + "ProjqWqbTilertWeightsAlias", +] + + +def projq_wqb( + q_nope_in: torch.Tensor, + wkv_b_a: torch.Tensor, + wkv_b_a_scales: torch.Tensor, + output: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Define the ProjqWqb operation. + + Args: + q_nope_in: Input tensor. + wkv_b_a: Weight tensor. + wkv_b_a_scales: Scale tensor. + output: Output tensor. + profile_logs: Profile logs tensor. + """ + if q_nope_in.shape[-1] == 128: + torch.ops.tilert.projq_wqb_op(q_nope_in, wkv_b_a, wkv_b_a_scales, output, profile_logs) + elif q_nope_in.shape[-1] == 192: + torch.ops.tilert.proj_qb_glm5_op(q_nope_in, wkv_b_a, wkv_b_a_scales, output, profile_logs) + + +class ProjqWqbAlgorithm(Enum): + """ProjqWqb algorithm""" + + GENERAL = "general" + + +class ProjqWqbWeightsConverter(TilertWeightsConverter): + def __init__(self, model_args: ModelArgs, num_devices: int, head_dim_block_size: int): + super().__init__(model_args, num_devices) + self.head_dim_block_size = head_dim_block_size + self.impl_block_size = 64 + + def convert_to_general(self, weights: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + with torch.inference_mode(): + tilert_wkv_b_weights, tilert_wkv_b_scales = weights + + n_local_heads = self.model_args.n_heads // self.num_devices + + wkv_b = tilert_wkv_b_weights + wkv_b_scales_raw = tilert_wkv_b_scales + wkv_b = wkv_b.view(n_local_heads, -1, self.model_args.kv_lora_rank) + assert self.model_args.kv_lora_rank % self.model_args.block_size == 0 + wkv_b_scales_raw = wkv_b_scales_raw.view( + n_local_heads, -1, self.model_args.kv_lora_rank // self.model_args.block_size + ) + wkv_b_a = wkv_b[:, : self.model_args.qk_nope_head_dim].transpose(1, 2).contiguous() + assert self.model_args.qk_nope_head_dim % self.head_dim_block_size == 0 + wkv_b_a_scales = ( + wkv_b_scales_raw[:, : self.model_args.qk_nope_head_dim // self.head_dim_block_size] + .transpose(1, 2) + .contiguous() + ) + if self.model_args.arch_name == "glm_5": + if wkv_b_a_scales.dtype != torch.float32: + print( + "Warning: ProjqWqbWeightsConverter: " + + f"wkv_b_a_scales.dtype: {wkv_b_a_scales.dtype} " + + "is not float32, convert to float32." + ) + wkv_b_a_scales = wkv_b_a_scales.to(torch.float32) + else: # DS v3.2, use bfloat16 for wkv_b_a_scales + wkv_b_a_scales = wkv_b_a_scales.to(torch.bfloat16) + # Tiling to fit tilert input + if self.head_dim_block_size != self.impl_block_size: + repeats = self.head_dim_block_size // self.impl_block_size + wkv_b_a_scales = wkv_b_a_scales.repeat(1, 1, repeats).contiguous() + + wkv_b_a = wkv_b_a.detach() + wkv_b_a_scales = wkv_b_a_scales.detach() + + return wkv_b_a, wkv_b_a_scales + + +@dataclass +class ProjqWqbRefWeightsAlias: + """Reference weights alias for ProjqWqb.""" + + wkv_b_weights = "self_attn.kv_b_proj.weight" + wkv_b_scales = "self_attn.kv_b_proj.weight_scale_inv" + + @property + def ref_tensor_alias(self) -> list[str]: + return [self.wkv_b_weights, self.wkv_b_scales] + + def __call__(self) -> list[str]: + return self.ref_tensor_alias + + +@dataclass +class ProjqWqbTilertWeightsAlias: + """TileRT weights alias for ProjqWqb.""" + + wkv_b_weights = "wkv_b1_weights" + wkv_b_scales = "wkv_b1_scales" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [self.wkv_b_weights, self.wkv_b_scales] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class ProjqWqb(TileRTModule): + """ProjqWqb module: Q projection (wkv_b) for KV LoRA.""" + + def __init__( + self, + model_args: ModelArgs, + num_devices: int, + device_id: int = 0, + ref_weights_alias: ProjqWqbRefWeightsAlias | None = None, + ): + super().__init__( + self.__class__.__name__, + model_args=model_args, + num_devices=num_devices, + device_id=device_id, + ) + + self.tilert_weights_alias = ProjqWqbTilertWeightsAlias() + self.ref_weights_alias = ( + ref_weights_alias if ref_weights_alias is not None else ProjqWqbRefWeightsAlias() + ) + + self.ref_wkv_b: torch.Tensor | None = None + self.tilert_wkv_b_a: torch.Tensor | None = None + self.tilert_wkv_b_a_scales: torch.Tensor | None = None + self.output: torch.Tensor | None = None + self.profile_logs: torch.Tensor | None = None + + self.num_local_heads = self.model_args.n_heads // self.num_devices + + # lora dim and quant block size + self.wkvb_lora_rank = self.model_args.kv_lora_rank + self.wkvb_lora_rank_qsize = self.wkvb_lora_rank // self.model_args.block_size + + self.wkvb_head_dim = self.model_args.qk_nope_head_dim + self.model_args.v_head_dim + self.wkvb_nope_head_dim = self.model_args.qk_nope_head_dim + left_head_dim = self.wkvb_head_dim % self.model_args.block_size + if left_head_dim != 0: + assert self.model_args.block_size % left_head_dim == 0 + self.head_dim_block_size = left_head_dim + self.head_dim_scale_repeat = self.model_args.block_size // self.head_dim_block_size + else: + self.head_dim_scale_repeat = 1 + self.head_dim_block_size = self.model_args.block_size + self.wkvb_head_qsize = self.wkvb_head_dim // self.head_dim_block_size + self.wkvb_nope_head_qsize = self.wkvb_nope_head_dim // self.head_dim_block_size + + @property + def tilert_tensor_alias(self) -> list[str]: + return self.tilert_weights_alias.tilert_tensor_alias + + def get_weights_list(self) -> list[torch.Tensor]: + return [self.tilert_wkv_b_a, self.tilert_wkv_b_a_scales] + + def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Device sharding: split weights and scales per device. + + Args: + weights_map: Map from ref weight alias to tensor. + + Returns: + Map from tilert weight alias to (num_devices, ...) tensors. + """ + kv_b_proj_weight = weights_map[self.ref_weights_alias.wkv_b_weights] + kv_b_proj_weight_scale = weights_map[self.ref_weights_alias.wkv_b_scales] + + dev_heads = (self.num_devices, self.num_local_heads) + wkvb = kv_b_proj_weight.view(*dev_heads, self.wkvb_head_dim, self.wkvb_lora_rank)[ + :, :, : self.wkvb_nope_head_dim + ] + wkvb_scales = ( + kv_b_proj_weight_scale.view( + self.num_devices, + self.num_local_heads * self.wkvb_head_dim // self.model_args.block_size, + 1, + self.wkvb_lora_rank_qsize, + ) + .contiguous() + .repeat(1, 1, self.head_dim_scale_repeat, 1) + .view( + self.num_devices, + self.num_local_heads, + self.wkvb_head_qsize, + self.wkvb_lora_rank_qsize, + ) + .contiguous()[:, :, : self.wkvb_nope_head_qsize] + ) + return { + self.tilert_weights_alias.wkv_b_weights: wkvb.contiguous(), + self.tilert_weights_alias.wkv_b_scales: wkvb_scales.contiguous(), + } + + def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + sharding_size = self.num_local_heads * self.wkvb_head_dim + sharding_start = self.device_id * sharding_size + sharding_end = sharding_start + sharding_size + wkv_b = weight_dequant( + state_dict[self.ref_weights_alias.wkv_b_weights], + state_dict[self.ref_weights_alias.wkv_b_scales], + ) + wkv_b = wkv_b[sharding_start:sharding_end, :] + wkv_b = wkv_b.view(self.num_local_heads, self.wkvb_head_dim, self.wkvb_lora_rank) + self.ref_wkv_b = wkv_b[:, : self.wkvb_nope_head_dim] + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + self.tilert_wkv_b_a, self.tilert_wkv_b_a_scales = ProjqWqbWeightsConverter( + self.model_args, self.num_devices, self.head_dim_block_size + ).dispatch( + ProjqWqbAlgorithm.GENERAL, + [ + state_dict[self.tilert_weights_alias.wkv_b_weights], + state_dict[self.tilert_weights_alias.wkv_b_scales], + ], + ) + + def init_random_weights(self) -> None: + wkv_b = init_func( + torch.empty( + self.model_args.n_heads * self.wkvb_head_dim, + self.wkvb_lora_rank, + dtype=torch.float8_e4m3fn, + ) + ) + wkv_b_scales = init_func( + torch.empty( + # Block quant should be applied to the original weight dimension (including head + # dimension) + self.model_args.n_heads * self.wkvb_head_dim // self.model_args.block_size, + self.wkvb_lora_rank_qsize, + dtype=torch.float32, + ) + ) + ref_state_dict = dict(zip(self.ref_weights_alias(), [wkv_b, wkv_b_scales])) + self.init_reference_weights(ref_state_dict) + sharded = self.device_sharding(ref_state_dict) + self.init_tilert_weights({k: v[self.device_id] for k, v in sharded.items()}) + + def init_tilert_vars(self, batch_size: int, seq_len: int) -> None: + self.output = torch.zeros( + (batch_size, seq_len, self.num_local_heads, self.wkvb_lora_rank), dtype=torch.bfloat16 + ) + self.profile_logs = get_profile_log_tensor() + self.is_var_init = True + + def golden_forward(self, q_nope: torch.Tensor) -> torch.Tensor: + assert self.ref_wkv_b is not None + return torch.einsum("bshd,hdc->bshc", q_nope, self.ref_wkv_b) + + def tilert_forward(self, q_nope: torch.Tensor) -> torch.Tensor: + assert self.tilert_wkv_b_a is not None + assert self.tilert_wkv_b_a_scales is not None + assert self.output is not None + assert self.profile_logs is not None + projq_wqb( + q_nope, + self.tilert_wkv_b_a, + self.tilert_wkv_b_a_scales, + self.output, + self.profile_logs, + ) + if self.flag_enable_profiling_log: + parse_profile_log_tensor( + self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)] + ) + return self.output diff --git a/python/models/deepseek_v3_2/ops/projx_wis.py b/python/models/deepseek_v3_2/ops/projx_wis.py new file mode 100644 index 0000000..e264659 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/projx_wis.py @@ -0,0 +1,157 @@ +"""ProjxWis operation module.""" + +from dataclasses import dataclass + +import torch + +from tilert.models.base import TileRTModule +from tilert.models.common import init_func +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.profiler.utils import parse_profile_log_tensor +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "projx_wis", + "ProjxWis", + "ProjxWisRefWeightsAlias", + "ProjxWisTilertWeightsAlias", +] + + +def projx_wis( + x_in: torch.Tensor, + w: torch.Tensor, + output: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Define the ProjxWis operation. + + Args: + x_in: Input tensor. + w: Weight tensor. + output: Output tensor. + profile_logs: Profile logs tensor. + """ + if x_in.shape[-1] == 7168: + torch.ops.tilert.proj_w_op(x_in, w, output, profile_logs) + elif x_in.shape[-1] == 6144: + torch.ops.tilert.proj_w_glm5_op(x_in, w, output, profile_logs) + + +@dataclass +class ProjxWisRefWeightsAlias: + """Reference weights alias for ProjxWis.""" + + w_weights = "self_attn.indexer.weights_proj.weight" + + @property + def ref_tensor_alias(self) -> list[str]: + return [self.w_weights] + + def __call__(self) -> list[str]: + return self.ref_tensor_alias + + +@dataclass +class ProjxWisTilertWeightsAlias: + """TileRT weights alias for ProjxWis.""" + + w_weights = "id_score_weights" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [self.w_weights] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class ProjxWis(TileRTModule): + """ProjxWis module: linear projection for indexer score weights.""" + + def __init__( + self, + model_args: ModelArgs, + num_devices: int, + device_id: int = 0, + ref_weights_alias: ProjxWisRefWeightsAlias | None = None, + ): + super().__init__( + self.__class__.__name__, + model_args=model_args, + num_devices=num_devices, + device_id=device_id, + ) + + self.tilert_weights_alias = ProjxWisTilertWeightsAlias() + self.ref_weights_alias = ( + ref_weights_alias if ref_weights_alias is not None else ProjxWisRefWeightsAlias() + ) + + # Backward compatibility: expose list for load_weights_for_layer etc. + self.ref_tensor_alias = self.ref_weights_alias.ref_tensor_alias + + self.ref_w: torch.Tensor | None = None + self.tilert_w: torch.Tensor | None = None + self.output: torch.Tensor | None = None + self.profile_logs: torch.Tensor | None = None + + self.dim = model_args.dim + self.index_n_heads = model_args.index_n_heads + + @property + def tilert_tensor_alias(self) -> list[str]: + return self.tilert_weights_alias.tilert_tensor_alias + + def get_weights_list(self) -> list[torch.Tensor]: + return [self.tilert_w] + + def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Device sharding: replicate weight for each device. + + Args: + weights_map: Map from ref weight alias to tensor. + + Returns: + Map from tilert weight alias to (num_devices, ...) tensors. + """ + w = weights_map[self.ref_weights_alias.w_weights][None, ...].repeat(self.num_devices, 1, 1) + return {self.tilert_weights_alias.w_weights: w} + + def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + w = state_dict[self.ref_weights_alias.w_weights] + self.ref_w = w.detach().clone().to(torch.bfloat16) + self.is_ref_weights_init = True + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + self.tilert_w = state_dict[self.tilert_weights_alias.w_weights].detach().clone() + self.is_tilert_weights_init = True + + def init_random_weights(self) -> None: + ref_w = init_func(torch.empty(self.index_n_heads, self.dim, dtype=torch.bfloat16)) + ref_state_dict = dict(zip(self.ref_weights_alias(), [ref_w])) + self.init_reference_weights(ref_state_dict) + sharded = self.device_sharding(ref_state_dict) + self.init_tilert_weights({k: v[self.device_id] for k, v in sharded.items()}) + + def init_tilert_vars(self, batch_size: int, seq_len: int) -> None: + self.output = torch.zeros((batch_size, seq_len, self.index_n_heads), dtype=torch.bfloat16) + self.profile_logs = get_profile_log_tensor() + self.is_var_init = True + + def golden_forward(self, x_norm: torch.Tensor) -> torch.Tensor: + assert self.ref_w is not None + return torch.nn.functional.linear(x_norm, self.ref_w) + + def tilert_forward(self, x_norm: torch.Tensor) -> torch.Tensor: + assert self.tilert_w is not None + assert self.output is not None + assert self.profile_logs is not None + projx_wis(x_norm, self.tilert_w, self.output, self.profile_logs) + if self.flag_enable_profiling_log: + parse_profile_log_tensor( + self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)] + ) + return self.output diff --git a/python/models/deepseek_v3_2/ops/qkv_rope.py b/python/models/deepseek_v3_2/ops/qkv_rope.py new file mode 100644 index 0000000..7a9a55a --- /dev/null +++ b/python/models/deepseek_v3_2/ops/qkv_rope.py @@ -0,0 +1,188 @@ +"""QKV Rope operation module. + +Unified for deepseek_v3_2 (n_local_heads=16) and glm_5 (n_local_heads=8). +Dispatches by q_pe.shape[2]: 16 -> qkv_rope_op, 8 -> qkv_rope_glm5_op. +""" + +from dataclasses import dataclass + +import torch + +from tilert.models.base import TileRTModule +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.utils import apply_rotary_emb +from tilert.profiler.utils import parse_profile_log_tensor +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "qkv_rope", + "QKVRoPE", + "QKVRoPERefWeightsAlias", + "QKVRoPETilertWeightsAlias", +] + + +def qkv_rope( + pe_cache: torch.Tensor, + kv_cache: torch.Tensor, + rope_freqs: torch.Tensor, + cur_pos: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Perform QKV Rope operation. + + Unified for deepseek_v3_2 (16 heads) and glm_5 (8 heads). Dispatches by + pe_cache (q_pe) shape[2]: 16 -> qkv_rope_op, 8 -> qkv_rope_glm5_op. + + Args: + pe_cache: Q PE tensor (bsz, seq, n_local_heads, qk_rope_head_dim). + kv_cache: K PE cache (bsz, seq, qk_rope_head_dim). + rope_freqs: Rope frequencies tensor. + cur_pos: Current position tensor. + profile_logs: Profile logs tensor. + """ + n_local_heads = pe_cache.shape[2] + qk_rope_head_dim = pe_cache.shape[3] + if qk_rope_head_dim != 64: + raise ValueError(f"Unsupported qk_rope_head_dim: {qk_rope_head_dim}") + + if n_local_heads == 16: + torch.ops.tilert.qkv_rope_op(pe_cache, kv_cache, rope_freqs, cur_pos, profile_logs) + elif n_local_heads == 8: + torch.ops.tilert.qkv_rope_glm5_op(pe_cache, kv_cache, rope_freqs, cur_pos, profile_logs) + else: + raise ValueError( + f"Unsupported n_local_heads: {n_local_heads}. " + "QKVRoPE supports n_local_heads=16 (deepseek_v3_2) or 8 (glm_5)." + ) + + +@dataclass +class QKVRoPERefWeightsAlias: + """Reference weights alias for QKVRoPE (no weights).""" + + @property + def ref_tensor_alias(self) -> list[str]: + return [] + + def __call__(self) -> list[str]: + return self.ref_tensor_alias + + +@dataclass +class QKVRoPETilertWeightsAlias: + """TileRT weights alias for QKVRoPE (no weights).""" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class QKVRoPE(TileRTModule): + """QKV RoPE module. Unified for deepseek_v3_2 and glm_5.""" + + def __init__( + self, + model_args: ModelArgs, + num_devices: int = 1, + device_id: int = 0, + layer_idx: int = 0, + ref_weights_alias: QKVRoPERefWeightsAlias | None = None, + ) -> None: + super().__init__( + self.__class__.__name__, + model_args=model_args, + num_devices=num_devices, + device_id=device_id, + layer_idx=layer_idx, + ) + self.tilert_weights_alias = QKVRoPETilertWeightsAlias() + self.ref_weights_alias = ( + ref_weights_alias if ref_weights_alias is not None else QKVRoPERefWeightsAlias() + ) + self.n_local_heads = model_args.n_heads // num_devices + self.qk_rope_head_dim = model_args.qk_rope_head_dim + self.profile_logs: torch.Tensor | None = None + + def get_weights_list(self) -> list[torch.Tensor]: + return [] + + def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + del weights_map + return {} + + def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + del state_dict + pass + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + del state_dict + pass + + def init_random_weights(self) -> None: + pass + + def init_tilert_vars(self, batch_size: int, seq_len: int) -> None: + del batch_size, seq_len + self.profile_logs = get_profile_log_tensor() + self.is_var_init = True + + def golden_forward( + self, + q_pe: torch.Tensor, + pe_cache: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + bsz: int, + seqlen: int, + ) -> torch.Tensor: + end_pos = start_pos + seqlen + + k_pe = pe_cache[:bsz, start_pos:end_pos] + k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) + pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) + + return apply_rotary_emb(q_pe, freqs_cis) + + def tilert_forward( + self, + q_pe: torch.Tensor, + pe_cache: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + bsz: int, + seqlen: int, + ) -> torch.Tensor: + assert self.profile_logs is not None + end_pos = start_pos + seqlen + + q_pe_rope = q_pe.clone() + rope_freqs = torch.view_as_real(freqs_cis).reshape(*freqs_cis.shape[:-1], -1) + cur_pos = torch.tensor([start_pos], dtype=torch.int32) + + qkv_rope( + q_pe_rope, pe_cache[:bsz, start_pos:end_pos], rope_freqs, cur_pos, self.profile_logs + ) + if self.flag_enable_profiling_log: + parse_profile_log_tensor( + self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)] + ) + + return q_pe_rope + + def __call__( + self, + q_pe: torch.Tensor, + pe_cache: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + bsz: int, + seqlen: int, + ) -> torch.Tensor: + if self.flag_enable_tilert: + return self.tilert_forward(q_pe, pe_cache, start_pos, freqs_cis, bsz, seqlen) + return self.golden_forward(q_pe, pe_cache, start_pos, freqs_cis, bsz, seqlen) diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_expert_proj.py b/python/models/deepseek_v3_2/ops/rmsnorm_expert_proj.py new file mode 100644 index 0000000..ce867a7 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/rmsnorm_expert_proj.py @@ -0,0 +1,163 @@ +"""RMSNormExpertProj operation module.""" + +from dataclasses import dataclass + +import torch +from torch import nn + +from tilert.models.base import TileRTModule +from tilert.models.common import RMSNorm, init_func, linear +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.profiler.utils import parse_profile_log_tensor +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "RMSNormExpertProj", + "RMSNormExpertProjRefWeightsAlias", + "RMSNormExpertProjTilertWeightsAlias", +] + + +@dataclass +class RMSNormExpertProjRefWeightsAlias: + """Reference weights alias for RMSNormExpertProj.""" + + post_attention_layernorm_weight = "post_attention_layernorm.weight" + mlp_gate_weight = "mlp.gate.weight" + + def __call__(self) -> list[str]: + return [self.post_attention_layernorm_weight, self.mlp_gate_weight] + + +@dataclass +class RMSNormExpertProjTilertWeightsAlias: + """TileRT weights alias for RMSNormExpertProj.""" + + unproj_o_gamma = "unproj_o_gamma" + exp_proj_weights = "exp_proj_weights" + + def __call__(self) -> list[str]: + return [self.unproj_o_gamma, self.exp_proj_weights] + + +class RMSNormExpertProj(TileRTModule): + """RMS Norm followed by expert projection.""" + + def __init__( + self, + model_args: ModelArgs, + num_devices: int, + device_id: int = 0, + ref_weights_alias: RMSNormExpertProjRefWeightsAlias | None = None, + tilert_weights_alias: RMSNormExpertProjTilertWeightsAlias | None = None, + ): + super().__init__( + type(self).__name__, + model_args=model_args, + num_devices=num_devices, + device_id=device_id, + ) + self.dim = model_args.dim + self.eps = model_args.eps + + self.ref_weights_alias = ( + ref_weights_alias + if ref_weights_alias is not None + else RMSNormExpertProjRefWeightsAlias() + ) + self.tilert_weights_alias = ( + tilert_weights_alias + if tilert_weights_alias is not None + else RMSNormExpertProjTilertWeightsAlias() + ) + + self.is_ref_weights_init = False + self.is_tilert_weights_init = False + + self.ref_rmsnorm: RMSNorm | None = None + self.ref_proj_weight: torch.Tensor | None = None + self.proj_weight = nn.Parameter( + init_func(torch.empty(model_args.n_routed_experts, model_args.dim)) + ) + self.n_routed_experts = model_args.n_routed_experts + + self.tilert_proj_weight: torch.Tensor | None = None + self.tilert_rms_norm_weight: torch.Tensor | None = None + + self.profile_logs = get_profile_log_tensor() + + def get_weights_list(self) -> list[torch.Tensor]: + return [self.tilert_rms_norm_weight, self.tilert_proj_weight] + + def device_sharding( + self, rms_norm_weight: torch.Tensor, proj_weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + return rms_norm_weight.float().contiguous(), proj_weight.contiguous() + + def init_reference_weights( + self, state_dict: dict[str, torch.Tensor], device_id: int | None = None + ) -> None: + del device_id + self.ref_rmsnorm = RMSNorm(self.dim, self.eps) + self.ref_rmsnorm.weight.data = state_dict[ + self.ref_weights_alias.post_attention_layernorm_weight + ] + self.ref_proj_weight = state_dict[self.ref_weights_alias.mlp_gate_weight] + self.is_ref_weights_init = True + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + self.tilert_proj_weight = ( + state_dict[self.tilert_weights_alias.exp_proj_weights].detach().clone() + ) + self.tilert_rms_norm_weight = ( + state_dict[self.tilert_weights_alias.unproj_o_gamma].detach().clone() + ) + self.is_tilert_weights_init = True + + def init_random_weights(self) -> None: + proj_weight = torch.randn(self.n_routed_experts, self.dim) + rms_norm_weight = torch.randn(self.dim, dtype=torch.float32) + ref_state_dict = dict( + zip( + self.ref_weights_alias(), + [rms_norm_weight, proj_weight], + ) + ) + self.init_reference_weights(ref_state_dict) + assert self.ref_rmsnorm is not None and self.ref_proj_weight is not None + sharded_weights = self.device_sharding(self.ref_rmsnorm.weight, self.ref_proj_weight) + self.init_tilert_weights(dict(zip(self.tilert_weights_alias(), sharded_weights))) + + def golden_forward( + self, x_in: torch.Tensor, residual: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + assert self.is_ref_weights_init, "Reference weights must be initialized before forward pass" + assert self.ref_rmsnorm is not None and self.ref_proj_weight is not None + norm_x = self.ref_rmsnorm(x_in, residual) + scores = linear(norm_x.view(-1, self.dim).float(), self.ref_proj_weight.float()) + return norm_x, scores + + def tilert_forward(self, x_in: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + assert self.is_tilert_weights_init, "Tilert weights must be initialized before forward pass" + assert self.tilert_rms_norm_weight is not None and self.tilert_proj_weight is not None + x_in = x_in.to(torch.bfloat16) + hidden_out = torch.zeros_like(x_in) + scores_out = torch.zeros( + (x_in.shape[0], x_in.shape[1], self.n_routed_experts), dtype=torch.float32 + ) + torch.ops.tilert.rmsnorm_expert_proj_op( + x_in, + self.tilert_rms_norm_weight, + self.tilert_proj_weight, + scores_out, + hidden_out, + self.profile_logs, + ) + if self.flag_enable_profiling_log: + parse_profile_log_tensor( + self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)] + ) + return hidden_out, scores_out + + def __call__(self, x_in: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return self.tilert_forward(x_in) diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_head_proj.py b/python/models/deepseek_v3_2/ops/rmsnorm_head_proj.py new file mode 100644 index 0000000..6145b5b --- /dev/null +++ b/python/models/deepseek_v3_2/ops/rmsnorm_head_proj.py @@ -0,0 +1,339 @@ +"""RMSNormHeadProj operation module.""" + +from collections.abc import Callable +from dataclasses import dataclass +from enum import Enum + +import torch + +from tilert.models.base import TileRTModule, TilertWeightsConverter +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "rmsnorm_head_proj", + "rmsnorm_head_proj_glm5", + "RMSNormHeadProj", + "RMSNormHeadProjTilertWeightsAlias", +] + + +def rmsnorm_head_proj( + hidden_in: torch.Tensor, + gamma_in: torch.Tensor, + weight_in: torch.Tensor, + logits_out: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """RMS Norm Head Projection operation.""" + torch.ops.tilert.rmsnorm_head_proj_op( + hidden_in, + gamma_in, + weight_in, + logits_out, + profile_logs, + ) + + +def rmsnorm_head_proj_dsv32( + hidden_in: torch.Tensor, + gamma_in: torch.Tensor, + weight_in: torch.Tensor, + hidden_rmsnorm_out: torch.Tensor, + logits_out: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """RMS Norm Head Projection operation.""" + del hidden_rmsnorm_out + torch.ops.tilert.rmsnorm_head_proj_op( + hidden_in, + gamma_in, + weight_in, + logits_out, + profile_logs, + ) + + +def rmsnorm_head_proj_glm5( + hidden_in: torch.Tensor, + gamma_in: torch.Tensor, + weight_in: torch.Tensor, + hidden_rmsnorm_out: torch.Tensor, + logits_out: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """RMS Norm Head Projection operation.""" + torch.ops.tilert.rmsnorm_head_proj_glm5_op( + hidden_in, + gamma_in, + weight_in, + hidden_rmsnorm_out, + logits_out, + profile_logs, + ) + + +class RMSNormHeadProjAlgorithm(Enum): + """RMSNormHeadProj algorithm""" + + GENERAL = "general" + + +class RMSNormHeadProjWeightsConverter(TilertWeightsConverter): + """RMSNormHeadProj weights converter""" + + @staticmethod + def tilert_to_tilert_native_bf16_warp_gemv( + tilert_weight_in: torch.Tensor, + ) -> torch.Tensor: + """Convert TILERT weights to TILERT native bf16 warp gemv weights.""" + weights = tilert_weight_in.reshape(1010, 16, 7, 1024) + weights = weights.transpose(1, 2).reshape(7070, 16, 1024) + return weights.contiguous() + + def convert_to_general( + self, weights_list: list[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert the weights to general format. + + Args: + weights_list: List of weights. + + Returns: + Tuple of weights. + """ + args = self.model_args + assert args.arch_name == "deepseek_v3_2" or args.arch_name == "glm_5" + + with torch.inference_mode(): + rmsnorm_gamma, mat_in = weights_list + logits_dim = mat_in.shape[-2] + dim = mat_in.shape[-1] + num_steps = dim // 1024 + assert dim % 1024 == 0 + weights = mat_in.reshape(logits_dim // 16, 16, num_steps, 1024) + weights = weights.transpose(1, 2).reshape(logits_dim // 16 * num_steps, 16, 1024) + return rmsnorm_gamma.float(), weights + + +@dataclass +class RMSNormHeadProjTilertWeightsAlias: + """TileRT weights alias for RMSNormHeadProj.""" + + model_norm_weight = "model.norm.weight" + lm_head_weight = "lm_head.weight" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [self.model_norm_weight, self.lm_head_weight] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class RMSNormHeadProj(TileRTModule): + """RMSNormHeadProj module""" + + def __init__( + self, + model_args: ModelArgs, + device_id: int, + num_devices: int, + algorithm: RMSNormHeadProjAlgorithm = RMSNormHeadProjAlgorithm.GENERAL, + ): + super().__init__( + self.__class__.__name__, + model_args=model_args, + device_id=device_id, + num_devices=num_devices, + ) + + self.arch_name = self.model_args.arch_name + self.dim = self.model_args.dim + self.logits_dim = self.model_args.vocab_size + self.algorithm = algorithm + self.eps = self.model_args.eps + + # reference weights + self.ref_rmsnorm_gamma: torch.Tensor | None = None + self.ref_head_proj: torch.Tensor | None = None + + # tilert weights + self.tilert_rmsnorm_gamma: torch.Tensor | None = None + self.tilert_head_proj: torch.Tensor | None = None + + # tilert vars + self.hidden_rmsnorm_out: torch.Tensor | None = None + self.hidden_out: torch.Tensor | None = None + + self.profile_logs: torch.Tensor | None = None + self.is_init = False + + # tilert_funcs + self.rmsnorm_head_proj_func: Callable | None = None + + if self.arch_name == "deepseek_v3_2": + self.rmsnorm_head_proj_func = rmsnorm_head_proj_dsv32 + elif self.arch_name == "glm_5": + self.rmsnorm_head_proj_func = rmsnorm_head_proj_glm5 + else: + raise ValueError(f"Unsupported architecture: {self.arch_name}") + + self.tilert_weights_alias = RMSNormHeadProjTilertWeightsAlias() + + # reference tensor aliases + self.ref_tensor_alias: list[str] = [ + "model.norm.weight", + "lm_head.weight", + ] + + @property + def tilert_tensor_alias(self) -> list[str]: + return self.tilert_weights_alias() + + def get_weights_list(self) -> list[torch.Tensor]: + """ + Get the weights list. + + Returns: + List of weights. + """ + return [self.tilert_rmsnorm_gamma, self.tilert_head_proj] + + def device_sharding( + self, + weights_dict: dict[str, torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Device sharding. + + Args: + weights_dict: Dictionary of weights. + key_prefix: Key prefix. + Returns: + Tuple of weights. + """ + rmsnorm_gamma_key = "model.norm.weight" + head_proj_key = "lm_head.weight" + rmsnorm_gamma = weights_dict[rmsnorm_gamma_key][None, ...] + # repeat number of devices times + rmsnorm_gamma = rmsnorm_gamma.repeat(self.num_devices, 1) + head_proj = weights_dict[head_proj_key] + + head_proj = head_proj.reshape(self.num_devices, -1, self.dim) + return rmsnorm_gamma.contiguous(), head_proj.contiguous() + + def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """ + Initialize the reference weights. + + Args: + state_dict: State dictionary. + device_id: Device ID. + """ + sharded_list = self.device_sharding(state_dict) + + gamma, head_proj = sharded_list[0][self.device_id], sharded_list[1][self.device_id] + self.ref_rmsnorm_gamma = gamma + self.ref_head_proj = head_proj + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """ + Initialize the tilert weights. + + Args: + state_dict: State dictionary. + """ + assert self.algorithm is not None + self.tilert_rmsnorm_gamma, self.tilert_head_proj = RMSNormHeadProjWeightsConverter( + self.model_args, self.num_devices + ).dispatch(self.algorithm, [state_dict[alias] for alias in self.tilert_weights_alias()]) + + def init_tilert_vars(self, batch_size: int, seq_len: int) -> None: + """ + Initialize the tilert variables. + + Args: + batch_size: Batch size. + seq_len: Sequence length. + """ + # tilert vars + self.hidden_rmsnorm_out = torch.zeros( + (batch_size, seq_len, self.dim), + dtype=torch.bfloat16, + device=f"cuda:{self.device_id}", + ) + self.hidden_out = torch.zeros( + (batch_size, seq_len, self.logits_dim // self.num_devices), + dtype=torch.float32, + device=f"cuda:{self.device_id}", + ) + self.profile_logs = get_profile_log_tensor(device=f"cuda:{self.device_id}") + self.is_init = True + + def init_random_weights(self, device_id: int = 0) -> None: + """Initialize the random weights.""" + rmsnorm_gamma = torch.randn(self.dim, dtype=torch.float32, device=f"cuda:{device_id}") + head_proj = torch.randn( + self.logits_dim, self.dim, dtype=torch.bfloat16, device=f"cuda:{device_id}" + ) + + tensor_list = [ + rmsnorm_gamma, + head_proj, + ] + state_dict = dict(zip(self.ref_tensor_alias, tensor_list)) + + self.init_reference_weights(state_dict) + sharded_list = self.device_sharding(state_dict) + sharded_state_dict = { + alias: sharded_list[i][self.device_id] + for i, alias in enumerate(self.tilert_weights_alias()) + } + self.init_tilert_weights(sharded_state_dict) + + def golden_forward( + self, + hidden_in: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass for the down-project module. + + Args: + hidden_in: Input hidden. + + Returns: + Output tensor. + """ + assert self.ref_rmsnorm_gamma is not None + assert self.ref_head_proj is not None + bsz = hidden_in.shape[0] + assert bsz == 1 + hidden_rmsnorm = torch.nn.functional.rms_norm( + hidden_in.float(), [hidden_in.size(-1)], self.ref_rmsnorm_gamma, self.eps + ) + return hidden_rmsnorm.float() @ self.ref_head_proj.T.float() + + def tilert_forward( + self, + hidden_in: torch.Tensor, + ) -> torch.Tensor: + assert self.rmsnorm_head_proj_func is not None + assert self.hidden_out is not None + + self.rmsnorm_head_proj_func( + hidden_in, + self.tilert_rmsnorm_gamma, + self.tilert_head_proj, + self.hidden_rmsnorm_out, + self.hidden_out, + self.profile_logs, + ) + return self.hidden_out + + def __call__( + self, + hidden_in: torch.Tensor, + ) -> torch.Tensor: + return self.golden_forward(hidden_in) diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_kv.py b/python/models/deepseek_v3_2/ops/rmsnorm_kv.py new file mode 100644 index 0000000..d9c9af0 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/rmsnorm_kv.py @@ -0,0 +1,184 @@ +"""RMSNormKV operation module.""" + +from dataclasses import dataclass + +import torch + +from tilert.models.base import TileRTModule +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.profiler.utils import parse_profile_log_tensor +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "rmsnorm_kv", + "KVRMSNorm", + "KVRMSNormRefWeightsAlias", + "KVRMSNormTilertWeightsAlias", +] + + +def rmsnorm_kv( + kv: torch.Tensor, + gamma: torch.Tensor, + cur_pos: torch.Tensor, + kv_cache: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Define the RMSNormKV operation. + + Args: + kv: Input tensor. + gamma: Weight tensor. + cur_pos: Current position tensor. + kv_cache: Output tensor. + profile_logs: Profile logs tensor. + """ + torch.ops.tilert.rmsnorm_kv_op(kv, gamma, cur_pos, kv_cache, profile_logs) + + +@dataclass +class KVRMSNormRefWeightsAlias: + """Reference weights alias for KVRMSNorm.""" + + kv_norm_weight = "self_attn.kv_a_layernorm.weight" + + @property + def ref_tensor_alias(self) -> list[str]: + return [self.kv_norm_weight] + + def __call__(self) -> list[str]: + return self.ref_tensor_alias + + +@dataclass +class KVRMSNormTilertWeightsAlias: + """TileRT weights alias for KVRMSNorm.""" + + kv_norm_gamma = "kv_rmsnorm_gamma" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [self.kv_norm_gamma] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class KVRMSNorm(TileRTModule): + """KVRMSNorm module: RMSNorm on KV tensor with in-place write to kv_cache.""" + + def __init__( + self, + model_args: ModelArgs, + num_devices: int, + device_id: int, + ref_weights_alias: KVRMSNormRefWeightsAlias | None = None, + tilert_weights_alias: KVRMSNormTilertWeightsAlias | None = None, + layer_idx: int = 0, + golden_weights_dir: str = "", + tilert_weights_dir: str = "", + ): + super().__init__( + self.__class__.__name__, + model_args=model_args, + num_devices=num_devices, + device_id=device_id, + layer_idx=layer_idx, + golden_weights_dir=golden_weights_dir, + tilert_weights_dir=tilert_weights_dir, + ) + + self.tilert_weights_alias = ( + tilert_weights_alias + if tilert_weights_alias is not None + else KVRMSNormTilertWeightsAlias() + ) + self.ref_weights_alias = ( + ref_weights_alias if ref_weights_alias is not None else KVRMSNormRefWeightsAlias() + ) + + self.kv_lora_rank = self.model_args.kv_lora_rank + self.eps = self.model_args.eps + + self.ref_norm_gamma: torch.Tensor | None = None + self.tilert_kv_norm_weight: torch.Tensor | None = None + self.profile_logs: torch.Tensor | None = None + + def get_weights_list(self) -> list[torch.Tensor]: + return [self.tilert_kv_norm_weight] + + def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Device sharding: replicate gamma for each device. + + Args: + weights_map: Map from ref weight alias to tensor. + + Returns: + Map from tilert weight alias to (num_devices, ...) tensors. + """ + gamma = weights_map[self.ref_weights_alias.kv_norm_weight][None, ...].repeat( + self.num_devices, 1 + ) + return {self.tilert_weights_alias.kv_norm_gamma: gamma} + + def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """Initialize reference weights from state dict.""" + self.ref_norm_gamma = state_dict[self.ref_weights_alias.kv_norm_weight].contiguous() + assert ( + self.ref_norm_gamma.shape[-1] == self.kv_lora_rank + ), f"kv_norm weight shape must be ({self.kv_lora_rank},), got {self.ref_norm_gamma.shape}" + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """Initialize TileRT weights from state dict.""" + gamma = state_dict[self.tilert_weights_alias.kv_norm_gamma] + self.tilert_kv_norm_weight = gamma.float().detach().clone().contiguous() + + def init_tilert_vars(self, batch_size: int, seq_len: int) -> None: + """Allocate TileRT profiling buffer.""" + del batch_size, seq_len + self.profile_logs = get_profile_log_tensor() + self.is_var_init = True + + def init_random_weights(self) -> None: + """Initialize random reference and TileRT weights for testing.""" + ref_state_dict = { + self.ref_weights_alias.kv_norm_weight: torch.randn( + self.kv_lora_rank, dtype=torch.float32 + ), + } + self.init_reference_weights(ref_state_dict) + sharded = self.device_sharding(ref_state_dict) + self.init_tilert_weights({k: v[self.device_id] for k, v in sharded.items()}) + + def golden_forward( + self, kv: torch.Tensor, kv_cache: torch.Tensor, start_pos: int, bsz: int, seqlen: int + ) -> None: + """Reference forward: RMSNorm and write to kv_cache.""" + assert self.ref_norm_gamma is not None + end_pos = start_pos + seqlen + out = torch.nn.functional.rms_norm( + kv.float(), [kv.size(-1)], self.ref_norm_gamma, self.eps + ).to(kv.dtype) + kv_cache[:bsz, start_pos:end_pos].copy_(out) + + def tilert_forward( + self, kv: torch.Tensor, kv_cache: torch.Tensor, start_pos: int, bsz: int, seqlen: int + ) -> None: + del seqlen + assert self.tilert_kv_norm_weight is not None + assert self.profile_logs is not None + cur_pos = torch.tensor([start_pos], dtype=torch.int32, device=kv.device) + rmsnorm_kv(kv, self.tilert_kv_norm_weight, cur_pos, kv_cache[:bsz], self.profile_logs) + if self.flag_enable_profiling_log: + parse_profile_log_tensor( + self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)] + ) + + def __call__( + self, kv: torch.Tensor, kv_cache: torch.Tensor, start_pos: int, bsz: int, seqlen: int + ) -> None: + if self.flag_enable_tilert: + return self.tilert_forward(kv, kv_cache, start_pos, bsz, seqlen) + return self.golden_forward(kv, kv_cache, start_pos, bsz, seqlen) diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_proj_top1.py b/python/models/deepseek_v3_2/ops/rmsnorm_proj_top1.py new file mode 100644 index 0000000..c6ec1a5 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/rmsnorm_proj_top1.py @@ -0,0 +1,29 @@ +"""RMSNorm + head projection + top1 operation""" + +import torch + +__all__ = [ + "rmsnorm_proj_top1", +] + + +def rmsnorm_proj_top1( + hidden_in: torch.Tensor, + rmsnorm_gamma_in: torch.Tensor, + head_projection_weights_in: torch.Tensor, + token_id: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Define the RMSNormProjTop1 operation. + + Args: + hidden_in: Input tensor. + rmsnorm_gamma_in: Weight tensor. + head_projection_weights_in: Weight tensor. + token_id: Output tensor. + profile_logs: Profile logs tensor. + """ + torch.ops.tilert.rmsnorm_proj_top1_op( + hidden_in, rmsnorm_gamma_in, head_projection_weights_in, token_id, profile_logs + ) diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_projq_wqib.py b/python/models/deepseek_v3_2/ops/rmsnorm_projq_wqib.py new file mode 100644 index 0000000..7adcad6 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/rmsnorm_projq_wqib.py @@ -0,0 +1,689 @@ +"""RmsnormProjqWqib operation module.""" + +from dataclasses import dataclass +from enum import Enum + +import torch +from einops import rearrange + +from tilert.models.base import TileRTModule, TilertWeightsConverter +from tilert.models.common import weight_dequant +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.deepseek_v3_2.ops.expert_sel_up_gate_silu import ( + ExpertSelectUpGateSiLUWeightsConverter as WeightsConverter, +) +from tilert.profiler.utils import parse_profile_log_tensor +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "RmsnormProjqWqib", + "RmsnormProjqWqibAlgorithm", + "RmsnormProjqWqibWeightsConverter", +] + + +def rmsnorm_projq_wqib_op( + q: torch.Tensor, + wq_b_full: torch.Tensor, + wq_b_full_scales: torch.Tensor, + q_norm_weight: torch.Tensor, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + iq: torch.Tensor, + profile_logs: torch.Tensor, + algorithm: str, +) -> None: + dim = q.shape[-1] + if dim == 1536: + impl_func = torch.ops.tilert.rmsnorm_proj_qb_iq_op + elif dim == 2048: + impl_func = torch.ops.tilert.rmsnorm_proj_qb_iq_glm5_op + else: + raise ValueError(f"Invalid dimension: {dim}") + impl_func( + q, + wq_b_full, + wq_b_full_scales, + q_norm_weight, + q_nope, + q_pe, + iq, + profile_logs, + algorithm, + ) + + +class RmsnormProjqWqibAlgorithm(Enum): + """RmsnormProjqWqib algorithm.""" + + BF16 = "bf16" + FP8 = "fp8" + FP16MMA = "fp16mma" + + +class RmsnormProjqWqibWeightsConverter(TilertWeightsConverter): + """Weights converter: common format to TileRT format.""" + + def __init__(self, model_args: ModelArgs, num_devices: int): + super().__init__(model_args=model_args, num_devices=num_devices) + + self.proc_groups = 8 + self.repeat = 16 + + self.block_size = self.model_args.block_size + self.n_local_heads = self.model_args.n_heads // self.num_devices + + self.q_lora_dim = self.model_args.q_lora_rank + self.q_lora_qdim = self.q_lora_dim // self.block_size + + self.qk_nope_head_dim = self.model_args.qk_nope_head_dim + self.qk_rope_head_dim = self.model_args.qk_rope_head_dim + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.qk_dim = self.qk_head_dim * self.n_local_heads + self.qk_qdim = self.qk_dim // self.block_size + + self.index_n_heads = self.model_args.index_n_heads + self.index_head_dim = self.index_n_heads * self.model_args.index_head_dim + self.index_head_qdim = self.index_head_dim // self.block_size + + def _common_to_tilert_bf16( + self, + wq_b: torch.Tensor, + wq_b_scales_raw: torch.Tensor, + wq_b_iq: torch.Tensor, + wq_b_iq_scales: torch.Tensor, + rmsnorm_gamma: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert common weights to TileRT BF16 layout.""" + wq_b = wq_b.reshape(self.n_local_heads, self.qk_head_dim, self.q_lora_dim) + wq_b_nope = wq_b[:, : self.qk_nope_head_dim, :] + wq_b_nope = wq_b_nope.reshape( + self.n_local_heads, + self.proc_groups, + self.qk_nope_head_dim // self.proc_groups, + self.q_lora_dim, + ) + wq_b_pe = wq_b[:, self.qk_nope_head_dim :, :] + wq_b_pe = wq_b_pe.reshape( + self.n_local_heads, + self.proc_groups, + self.qk_rope_head_dim // self.proc_groups, + self.q_lora_dim, + ) + wq_b = torch.cat([wq_b_nope, wq_b_pe], dim=2) + wq_b = wq_b.reshape(self.qk_dim, self.q_lora_dim) + wq_b_full = torch.cat([wq_b, wq_b_iq], dim=0) + + wq_b_scales_iq_raw = wq_b_iq_scales + wq_b_scales_t16 = ( + wq_b_scales_raw.reshape((self.qk_qdim, 1, self.q_lora_qdim)) + .repeat(1, self.repeat, 1) + .reshape(self.qk_qdim * self.repeat, self.q_lora_qdim) + ) + wq_b_scales_t16 = wq_b_scales_t16.reshape( + self.n_local_heads, self.qk_head_dim // self.proc_groups, self.q_lora_qdim + ) + wq_b_scales_t16_nope = wq_b_scales_t16[:, : self.qk_nope_head_dim // 8] + wq_b_scales_t16_pe = wq_b_scales_t16[:, self.qk_nope_head_dim // 8 :] + wq_b_scales_t16_nope = wq_b_scales_t16_nope.reshape( + self.n_local_heads, + self.proc_groups, + self.qk_nope_head_dim // 8 // self.proc_groups, + self.q_lora_qdim, + ) + wq_b_scales_t16_pe = wq_b_scales_t16_pe.reshape( + self.n_local_heads, + self.proc_groups, + self.qk_rope_head_dim // 8 // self.proc_groups, + self.q_lora_qdim, + ) + wq_b_scales_t16 = torch.cat([wq_b_scales_t16_nope, wq_b_scales_t16_pe], dim=2) + wq_b_scales_t16 = wq_b_scales_t16.reshape(-1, self.q_lora_qdim) + wq_b_scales_full = torch.cat([wq_b_scales_t16, wq_b_scales_iq_raw], dim=0) + + return ( + wq_b_full.detach().clone(), + wq_b_scales_full.detach().clone(), + rmsnorm_gamma.float().detach().clone(), + ) + + def _common_to_tilert_fp8( + self, + wq_b: torch.Tensor, + wq_b_scales_raw: torch.Tensor, + wq_b_iq: torch.Tensor, + wq_b_iq_scales_raw: torch.Tensor, + rmsnorm_gamma: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert common weights to TileRT FP8 MMA layout.""" + # Reshape wq_b: simple split of nope and pe, then concatenate + wq_b = wq_b.reshape(self.n_local_heads, self.qk_head_dim, self.q_lora_dim) + wq_b_nope = wq_b[:, : self.qk_nope_head_dim, :].reshape(-1, self.q_lora_dim) + wq_b_pe = wq_b[:, self.qk_nope_head_dim :, :].reshape(-1, self.q_lora_dim) + wq_b = torch.cat([wq_b_nope, wq_b_pe], dim=0) + + # Process scales: expand and split nope/pe similarly to weights + m_scale_group = self.block_size // self.repeat + wq_b_scales_t16 = ( + wq_b_scales_raw.reshape((self.qk_qdim, 1, self.q_lora_qdim)) + .repeat(1, self.repeat, 1) + .reshape(-1, self.qk_head_dim // m_scale_group, self.q_lora_qdim) + ) + + # Split nope and pe parts + wq_b_scales_nope = wq_b_scales_t16[:, : self.qk_nope_head_dim // m_scale_group, :].reshape( + [-1, self.q_lora_qdim] + ) + wq_b_scales_pe = wq_b_scales_t16[:, self.qk_nope_head_dim // m_scale_group :, :].reshape( + [-1, self.q_lora_qdim] + ) + wq_b_scales_t16 = torch.cat([wq_b_scales_nope, wq_b_scales_pe], dim=0) + + # Process wq_b_iq scales + wq_b_iq_scales_t16 = ( + wq_b_iq_scales_raw.reshape([self.index_head_qdim, 1, self.q_lora_qdim]) + .repeat([1, self.repeat, 1]) + .reshape((-1, self.q_lora_qdim)) + ) + + # Concatenate weights and scales + wq_b_raw = torch.cat([wq_b, wq_b_iq], dim=0) + page_k = self.q_lora_qdim + total_out_dim = self.qk_dim + self.index_head_dim + total_out_qdim = total_out_dim // self.block_size + wq_b_scales_full = ( + torch.cat( + [wq_b_scales_t16.to(torch.float32), wq_b_iq_scales_t16.to(torch.float32)], dim=0 + ) + .reshape((total_out_qdim, self.repeat, page_k, self.q_lora_qdim // page_k)) + .permute([0, 2, 1, 3]) + .contiguous() + .view(torch.float8_e4m3fn) + ) + + wq_b_raw = wq_b_raw.reshape( + [total_out_qdim, 128 // 16, 16, page_k, self.q_lora_dim // 32 // page_k, 32] + ).permute([0, 3, 1, 4, 2, 5]) + wq_b_raw = WeightsConverter._swizzle_mma_16x32(wq_b_raw) + + tilert_wq_b_full = torch.cat( + [ + wq_b_raw.reshape((total_out_qdim, page_k, -1)), + wq_b_scales_full.reshape([total_out_qdim, page_k, -1]), + ], + -1, + ).contiguous() + # TODO: use fp32 scale for glm_5 + tilert_wq_b_full_scales = torch.zeros(1, dtype=torch.bfloat16) + tilert_q_norm_weight = rmsnorm_gamma.float().detach().clone() + return tilert_wq_b_full, tilert_wq_b_full_scales, tilert_q_norm_weight + + @staticmethod + def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor: + assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16 + # PTX isa fig.88 + pre_shape = mat_in.shape[:-2] + mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4) + return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2) + + @staticmethod + def _swizzle_mma_16x16_for_16x2048_4pages(mat_in: torch.Tensor) -> torch.Tensor: + assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 2048 + pre_shape = mat_in.shape[:-2] + mat_in = mat_in.reshape(*pre_shape, 16, 4, 512).transpose(-3, -2) + mat_in = mat_in.reshape(*pre_shape, 4, 16, 32, 16).transpose(-3, -2) + mat_in = RmsnormProjqWqibWeightsConverter._swizzle_mma_16x16(mat_in) + return mat_in.contiguous() + + def _common_to_tilert_fp16mma( + self, + wq_b: torch.Tensor, + wq_b_scale: torch.Tensor, + wq_b_iq: torch.Tensor, + wq_b_iq_scale: torch.Tensor, + q_norm_weight: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert common weights to TileRT FP16 MMA layout.""" + assert self.model_args.arch_name == "glm_5", "Only GLM-5 supports FP16 MMA" + + if wq_b_scale.dtype != torch.float32: + print( + "Warning: RmsnormProjqWqibWeightsConverter: " + + f"wq_b_scale.dtype: {wq_b_scale.dtype} " + + "is not float32, convert to float32." + ) + wq_b_scale = wq_b_scale.to(torch.float32) + if wq_b_iq_scale.dtype != torch.float32: + print( + "Warning: RmsnormProjqWqibWeightsConverter: " + + f"wq_b_iq_scale.dtype: {wq_b_iq_scale.dtype} " + + "is not float32, convert to float32." + ) + wq_b_iq_scale = wq_b_iq_scale.to(torch.float32) + + sms = 128 # use 128 sms for glm_5 + pages = 4 + qk_dim = self.qk_head_dim * self.n_local_heads + qk_dim_per_sm = qk_dim // sms # 16 per sm + qk_nope_dim = self.n_local_heads * self.qk_nope_head_dim + qk_pe_dim = self.n_local_heads * self.qk_rope_head_dim + iq_dim_per_sm = self.index_head_dim // sms # 32 per sm + + wq_b_scale = wq_b_scale.reshape( + self.n_local_heads, self.qk_head_dim // self.block_size, 1, self.q_lora_qdim + ).repeat( + 1, 1, self.block_size, 1 + ) # 2048, 2048//128 + + wq_b_scale = wq_b_scale.reshape(self.n_local_heads, self.qk_head_dim, -1) + wq_b_nope_scale = ( + wq_b_scale[:, : self.qk_nope_head_dim, :] + .reshape(qk_nope_dim // qk_dim_per_sm, qk_dim_per_sm, pages, self.q_lora_qdim // pages) + .transpose(1, 2) # (96, 4, 16, 4) for glm_5 + ) + + wq_b_pe_scale = ( + wq_b_scale[:, self.qk_nope_head_dim :, :] + .reshape(qk_pe_dim // qk_dim_per_sm, qk_dim_per_sm, pages, self.q_lora_qdim // pages) + .transpose(1, 2) # (32, 4, 16, 4) for glm_5 + ) + wq_b_scale = torch.cat([wq_b_nope_scale, wq_b_pe_scale], dim=0) + wq_b_scale = wq_b_scale[:, :, 0, :] # (128, 4, 4) for glm_5 + + wq_b_iq_scale = wq_b_iq_scale.reshape(self.index_head_qdim, 1, self.q_lora_qdim).repeat( + 1, self.block_size, 1 + ) # (4096, 16) for glm_5 + wq_b_iq_scale = wq_b_iq_scale.reshape( + sms, iq_dim_per_sm, pages, self.q_lora_qdim // pages + ).transpose(1, 2) + wq_b_iq_scale = wq_b_iq_scale[:, :, 0, :] # (128, 4, 4) for glm_5 + + wq_b_full_scales = ( + torch.cat([wq_b_scale, wq_b_iq_scale], dim=-1).contiguous().view(torch.float8_e4m3fn) + ) # (128, 4, 8x4) for glm_5 + + wq_b = wq_b.reshape(self.n_local_heads, self.qk_head_dim, self.q_lora_dim) + wq_b_nope = wq_b[:, : self.qk_nope_head_dim, :].reshape(-1, self.q_lora_dim) # 8x192, 2048 + wq_b_nope = RmsnormProjqWqibWeightsConverter._swizzle_mma_16x16_for_16x2048_4pages( + wq_b_nope.reshape(qk_nope_dim // qk_dim_per_sm, qk_dim_per_sm, self.q_lora_dim) + ) + wq_b_nope = wq_b_nope.reshape(qk_nope_dim // qk_dim_per_sm, pages, qk_dim_per_sm, -1) + # (96, 4, 16, 512) for glm_5 + + wq_b_pe = wq_b[:, self.qk_nope_head_dim :, :].reshape(-1, self.q_lora_dim) # 8x64, 2048 + wq_b_pe = RmsnormProjqWqibWeightsConverter._swizzle_mma_16x16_for_16x2048_4pages( + wq_b_pe.reshape(qk_pe_dim // qk_dim_per_sm, qk_dim_per_sm, self.q_lora_dim) + ) + wq_b_pe = wq_b_pe.reshape(qk_pe_dim // qk_dim_per_sm, pages, qk_dim_per_sm, -1) + # (32, 4, 16, 512) for glm_5 + wq_b = torch.cat([wq_b_nope, wq_b_pe], dim=0) + # (128, 4, 16, 512) for glm_5 + + wq_b_iq = RmsnormProjqWqibWeightsConverter._swizzle_mma_16x16_for_16x2048_4pages( + wq_b_iq.reshape(sms, 2, iq_dim_per_sm // 2, self.q_lora_dim) + ) + wq_b_iq = ( + wq_b_iq.reshape(sms, 2, pages, iq_dim_per_sm // 2, -1) + .transpose(1, 2) + .reshape(sms, pages, iq_dim_per_sm, -1) + ) + # (128, 4, 32, 512) for glm_5 + wq_b = torch.cat([wq_b, wq_b_iq], dim=2) + wq_b = wq_b.reshape(sms, pages, -1) + # (128, 4, 48*512) for glm_5 + wq_b_scales_padding = torch.zeros( + sms, + pages, + 128 - wq_b_full_scales.shape[-1], + dtype=torch.float8_e4m3fn, + device=wq_b.device, + ) # append 128-byte aligned scale: (128, 4, 24704) for glm_5 + tilert_wq_b_full = torch.cat( + [wq_b, wq_b_full_scales, wq_b_scales_padding], dim=-1 + ).contiguous() + tilert_wq_b_dummy_scales = torch.zeros(1, dtype=torch.bfloat16) + tilert_q_norm_weight = q_norm_weight.float().detach().clone() + return tilert_wq_b_full, tilert_wq_b_dummy_scales, tilert_q_norm_weight + + def convert_to_bf16( + self, weights: list[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert common-format weights to TileRT BF16 layout. + + Args: + weights: [q_norm_weight, wq_b, wq_b_scale, wq_b_iq, wq_b_iq_scale]. + """ + with torch.inference_mode(): + wq_b, wq_b_scale, wq_b_iq, wq_b_iq_scale, q_norm_weight = weights + if self.model_args.arch_name == "glm_5": + if wq_b_scale.dtype != torch.float32: + print( + "Warning: RmsnormProjqWqibWeightsConverter: " + + f"wq_b_scale.dtype: {wq_b_scale.dtype} " + + "is not float32, convert to float32." + ) + wq_b_scales = wq_b_scale.to(torch.float32) + wq_b_iq_scales = wq_b_iq_scale.to(torch.float32) + return self._common_to_tilert_bf16( + wq_b, + wq_b_scales, + wq_b_iq, + wq_b_iq_scales, + q_norm_weight, + ) + + # DS v3.2, use bfloat16 for wq_b_scale and wq_b_iq_scale + wq_b_scales_bf16 = wq_b_scale.to(torch.bfloat16) + wq_b_iq_scales_bf16 = wq_b_iq_scale.to(torch.bfloat16) + return self._common_to_tilert_bf16( + wq_b, + wq_b_scales_bf16, + wq_b_iq, + wq_b_iq_scales_bf16, + q_norm_weight, + ) + + def convert_to_fp8( + self, weights: list[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert common-format weights to TileRT FP8 MMA layout. + + Args: + weights: [q_norm_weight, wq_b, wq_b_scale, wq_b_iq, wq_b_iq_scale]. + """ + with torch.inference_mode(): + wq_b, wq_b_scale, wq_b_iq, wq_b_iq_scale, q_norm_weight = weights + return self._common_to_tilert_fp8( + wq_b, + wq_b_scale, + wq_b_iq, + wq_b_iq_scale, + q_norm_weight, + ) + + def convert_to_fp16mma( + self, weights: list[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert common-format weights to TileRT FP16 MMA layout. + + Args: + weights: [q_norm_weight, wq_b, wq_b_scale, wq_b_iq, wq_b_iq_scale]. + """ + with torch.inference_mode(): + wq_b, wq_b_scale, wq_b_iq, wq_b_iq_scale, q_norm_weight = weights + return self._common_to_tilert_fp16mma( + wq_b, + wq_b_scale, + wq_b_iq, + wq_b_iq_scale, + q_norm_weight, + ) + + +@dataclass +class RmsnormProjqWqibRefWeightsAlias: + """Reference weights alias for RmsnormProjqWqib.""" + + rmsnorm_gamma = "self_attn.q_a_layernorm.weight" + wqb_weights = "self_attn.q_b_proj.weight" + wqb_scales = "self_attn.q_b_proj.weight_scale_inv" + wi_weights = "self_attn.indexer.wq_b.weight" + wi_scales = "self_attn.indexer.wq_b.weight_scale_inv" + + @property + def ref_tensor_alias(self) -> list[str]: + return [ + self.rmsnorm_gamma, + self.wqb_weights, + self.wqb_scales, + self.wi_weights, + self.wi_scales, + ] + + def __call__(self) -> list[str]: + return self.ref_tensor_alias + + +@dataclass +class RmsnormProjqWqibTilertWeightsAlias: + """TileRT weights alias for RmsnormProjqWqib.""" + + rmsnorm_gamma = "q_rmsnorm_gamma" + wqb_weights = "wqb_weights" + wqb_scales = "wqb_scales" + wi_weights = "wi_weights" + wi_scales = "wi_scales" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [ + self.rmsnorm_gamma, + self.wqb_weights, + self.wqb_scales, + self.wi_weights, + self.wi_scales, + ] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class RmsnormProjqWqib(TileRTModule): + """RmsnormProjqWqib module: RMSNorm + Q projection (wq_b + wq_b_iq).""" + + def __init__( + self, + model_args: ModelArgs, + device_id: int, + num_devices: int, + ref_weights_alias: RmsnormProjqWqibRefWeightsAlias | None = None, + ): + super().__init__( + self.__class__.__name__, + model_args=model_args, + device_id=device_id, + num_devices=num_devices, + ) + + self.tilert_weights_alias = RmsnormProjqWqibTilertWeightsAlias() + self.ref_weights_alias = ( + ref_weights_alias + if ref_weights_alias is not None + else RmsnormProjqWqibRefWeightsAlias() + ) + + self.n_local_heads = model_args.n_heads // num_devices + self.q_lora_rank = model_args.q_lora_rank + self.index_n_heads = model_args.index_n_heads + self.head_dim = model_args.index_head_dim + self.index_head_dim = model_args.index_n_heads * model_args.index_head_dim + self.n_heads = model_args.n_heads + self.qk_head_dim = model_args.qk_nope_head_dim + model_args.qk_rope_head_dim + self.qk_local_dim = self.qk_head_dim * self.n_local_heads + self.qk_nope_head_dim = model_args.qk_nope_head_dim + self.qk_rope_head_dim = model_args.qk_rope_head_dim + + # quantize block size + self.block_size = model_args.block_size + self.q_lora_qdim = self.q_lora_rank // self.block_size + self.qk_local_qdim = self.qk_local_dim // self.block_size + self.index_head_qdim = self.index_head_dim // self.block_size + self.eps = model_args.eps + + self.ref_q_norm: torch.Tensor | None = None + self.ref_wq_b: torch.Tensor | None = None + self.ref_wq_b_iq: torch.Tensor | None = None + + self.tilert_wq_b_full: torch.Tensor | None = None + self.tilert_wq_b_full_scales: torch.Tensor | None = None + self.tilert_q_norm_weight: torch.Tensor | None = None + + self.q_nope: torch.Tensor | None = None + self.q_pe: torch.Tensor | None = None + self.iq: torch.Tensor | None = None + + self.profile_logs: torch.Tensor | None = None + + def get_weights_list(self) -> list[torch.Tensor]: + return [self.tilert_q_norm_weight, self.tilert_wq_b_full, self.tilert_wq_b_full_scales] + + def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Device sharding.""" + gamma = weights_map[self.ref_weights_alias.rmsnorm_gamma][None, ...].repeat( + self.num_devices, 1 + ) + + sharded_wqb_weights = weights_map[self.ref_weights_alias.wqb_weights].reshape( + self.num_devices, self.qk_local_dim, self.q_lora_rank + ) + sharded_wi_weights = weights_map[self.ref_weights_alias.wi_weights][None, ...].repeat( + self.num_devices, 1, 1 + ) + + sharded_wqb_scales = weights_map[self.ref_weights_alias.wqb_scales].reshape( + self.num_devices, self.qk_local_qdim, self.q_lora_qdim + ) + sharded_wi_scales = weights_map[self.ref_weights_alias.wi_scales][None, ...].repeat( + self.num_devices, 1, 1 + ) + + return { + self.tilert_weights_alias.rmsnorm_gamma: gamma, + self.tilert_weights_alias.wqb_weights: sharded_wqb_weights, + self.tilert_weights_alias.wqb_scales: sharded_wqb_scales, + self.tilert_weights_alias.wi_weights: sharded_wi_weights, + self.tilert_weights_alias.wi_scales: sharded_wi_scales, + } + + def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """Initialize reference weights from common-format state dict.""" + self.ref_q_norm = state_dict[self.ref_weights_alias.rmsnorm_gamma] + qk_local_dim_start = self.qk_local_dim * self.device_id + qk_local_qdim_start = qk_local_dim_start // self.block_size + qk_local_dim_end = qk_local_dim_start + self.qk_local_dim + qk_local_qdim_end = qk_local_dim_end // self.block_size + wq_b = weight_dequant( + state_dict[self.ref_weights_alias.wqb_weights][qk_local_dim_start:qk_local_dim_end], + state_dict[self.ref_weights_alias.wqb_scales][qk_local_qdim_start:qk_local_qdim_end], + ) + wq_b_iq = weight_dequant( + state_dict[self.ref_weights_alias.wi_weights], + state_dict[self.ref_weights_alias.wi_scales], + ) + self.ref_wq_b = wq_b.contiguous() + self.ref_wq_b_iq = wq_b_iq.contiguous() + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """Initialize TileRT weights from common-format state dict.""" + weights = [ + state_dict[_k] + for _k in [ + self.tilert_weights_alias.wqb_weights, + self.tilert_weights_alias.wqb_scales, + self.tilert_weights_alias.wi_weights, + self.tilert_weights_alias.wi_scales, + self.tilert_weights_alias.rmsnorm_gamma, + ] + ] + assert self.algorithm is not None, "Algorithm is not set" + self.tilert_wq_b_full, self.tilert_wq_b_full_scales, self.tilert_q_norm_weight = ( + RmsnormProjqWqibWeightsConverter(self.model_args, self.num_devices).dispatch( + self.algorithm, weights + ) + ) + + def init_random_weights(self) -> None: + """Initialize random reference and TileRT weights for testing.""" + q_norm = torch.randn(self.q_lora_rank, dtype=torch.float32) + wq_b = torch.randn( + self.num_devices * self.qk_local_dim, self.q_lora_rank, dtype=torch.bfloat16 + ).to(torch.float8_e4m3fn) + scale_dtype = torch.float32 if self.model_args.arch_name == "glm_5" else torch.bfloat16 + wq_b_scale = torch.randn( + self.num_devices * self.qk_local_qdim, self.q_lora_qdim, dtype=scale_dtype + ) + wq_b_iq = torch.randn(self.index_head_dim, self.q_lora_rank, dtype=torch.bfloat16).to( + torch.float8_e4m3fn + ) + wq_b_iq_scale = torch.randn(self.index_head_qdim, self.q_lora_qdim, dtype=scale_dtype) + ref_state = { + self.ref_weights_alias.rmsnorm_gamma: q_norm, + self.ref_weights_alias.wqb_weights: wq_b, + self.ref_weights_alias.wqb_scales: wq_b_scale, + self.ref_weights_alias.wi_weights: wq_b_iq, + self.ref_weights_alias.wi_scales: wq_b_iq_scale, + } + + self.init_reference_weights(ref_state) + self.init_tilert_weights( + {_k: _v[self.device_id] for _k, _v in self.device_sharding(ref_state).items()} + ) + + def init_tilert_vars(self, batch_size: int, seq_len: int) -> None: + """Allocate TileRT output buffers.""" + self.q_nope = torch.zeros( + batch_size, seq_len, self.n_local_heads, self.qk_nope_head_dim, dtype=torch.bfloat16 + ) + self.q_pe = torch.zeros( + batch_size, seq_len, self.n_local_heads, self.qk_rope_head_dim, dtype=torch.bfloat16 + ) + self.iq = torch.zeros( + batch_size, seq_len, self.index_n_heads, self.head_dim, dtype=torch.bfloat16 + ) + self.profile_logs = get_profile_log_tensor() + self.is_var_init = True + + def golden_forward(self, q: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Reference forward: RMSNorm + linear projections.""" + assert self.ref_q_norm is not None + assert self.ref_wq_b is not None + assert self.ref_wq_b_iq is not None + + bsz, seqlen, _ = q.shape + if bsz != 1 or seqlen not in [1, 2, 4]: + raise ValueError(f"Invalid batch size or sequence length: bsz={bsz}, seqlen={seqlen}") + + qr = torch.nn.functional.rms_norm(q.float(), [q.size(-1)], self.ref_q_norm, self.eps).to( + q.dtype + ) + + q = torch.matmul(qr, self.ref_wq_b.T) + q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_idx = torch.matmul(qr, self.ref_wq_b_iq.T) + q_idx = rearrange(q_idx, "b s (h d) -> b s h d", d=self.head_dim) + return q_nope, q_pe, q_idx + + def tilert_forward(self, q: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert self.tilert_wq_b_full is not None + assert self.tilert_wq_b_full_scales is not None + assert self.tilert_q_norm_weight is not None + assert self.q_nope is not None + assert self.q_pe is not None + assert self.iq is not None + assert self.profile_logs is not None + + bsz, seqlen, _ = q.shape + if bsz != 1 or seqlen not in [1, 2, 4]: + raise ValueError(f"Invalid batch size or sequence length: bsz={bsz}, seqlen={seqlen}") + + assert self.algorithm is not None, "Algorithm is not set" + + rmsnorm_projq_wqib_op( + q, + self.tilert_wq_b_full, + self.tilert_wq_b_full_scales, + self.tilert_q_norm_weight, + self.q_nope, + self.q_pe, + self.iq, + self.profile_logs, + self.algorithm.value, + ) + + if self.flag_enable_profiling_log: + torch.cuda.synchronize() + parse_profile_log_tensor( + self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)] + ) + return self.q_nope, self.q_pe, self.iq diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_projx_wqkvia.py b/python/models/deepseek_v3_2/ops/rmsnorm_projx_wqkvia.py new file mode 100644 index 0000000..d6538ed --- /dev/null +++ b/python/models/deepseek_v3_2/ops/rmsnorm_projx_wqkvia.py @@ -0,0 +1,1095 @@ +"""RMSNormProjxWqkvia operation module.""" + +from collections.abc import Callable +from dataclasses import dataclass +from enum import Enum + +# from typing import Any +import torch + +from tilert.models.base import TileRTModule, TilertWeightsConverter +from tilert.models.common import weight_dequant +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.deepseek_v3_2.ops.rmsnorm_quant import rmsnorm_quant +from tilert.profiler.utils import parse_profile_log_tensor +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "RMSNormProjQAKVAKIWeightsConverter", + "RMSNormProjxWqkviaAlgorithm", + "RMSNormProjxWqkvia", + "RMSNormProjxWqkviaRefWeightsAlias", + "RMSNormProjxWqkviaTilertWeightsAlias", + "rmsnorm_projx_wqkvia", + "projx_wqkvia", +] + + +def rmsnorm_projx_wqkvia( + x_in: torch.Tensor, + wqkv_a: torch.Tensor, + wqkv_a_scales: torch.Tensor, + rmsnorm_gamma: torch.Tensor, + cur_pos: torch.Tensor, + q_out: torch.Tensor, + kv_out: torch.Tensor, + pe_cache: torch.Tensor, + ki_out: torch.Tensor, + x_rmsnorm_out: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + rmsnorm_projx_wqkvia operation. + + Args: + x_in: Input tensor. + wqkv_a: QKV weights. + wqkv_a_scales: QKV scales. + rmsnorm_gamma: RMSNorm gamma. + cur_pos: Current position. + q_out: Q output tensor. + kv_out: KV output tensor. + pe_cache: PE cache tensor. + ki_out: Ki output tensor. + x_rmsnorm_out: RMSNorm output tensor. + profile_logs: Profile logs tensor. + """ + torch.ops.tilert.rmsnorm_proj_qa_kva_ki_op( + x_in, + wqkv_a, + wqkv_a_scales, + rmsnorm_gamma, + cur_pos, + q_out, + kv_out, + pe_cache, + ki_out, + x_rmsnorm_out, + profile_logs, + ) + + +def projx_wqkvia( + x_quant: torch.Tensor, + x_scale: torch.Tensor, + wqkvia: torch.Tensor, + cur_pos: torch.Tensor, + out_q: torch.Tensor, + out_kv: torch.Tensor, + pe_cache: torch.Tensor, + out_ki: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Define the ProjXWQKVIa operation. + + Args: + x_quant: Input tensor. + x_scale: Weight tensor. + wqkvia: Weight tensor. + cur_pos: Current position tensor. + out_q: Output tensor. + out_kv: Output tensor. + pe_cache: Output tensor. + out_ki: Output tensor. + profile_logs: Profile logs tensor. + """ + dim = x_quant.shape[-1] + if dim == 6144: + func_call = torch.ops.tilert.projx_wqkvia_glm5 + elif dim == 7168: + func_call = torch.ops.tilert.projx_wqkvia_op + else: + raise ValueError(f"Unsupported dimension: {dim}") + func_call(x_quant, x_scale, wqkvia, cur_pos, out_q, out_kv, pe_cache, out_ki, profile_logs) + + +class RMSNormProjxWqkviaAlgorithm(Enum): + """RMSNormProjxWqkvia algorithm""" + + GENERAL = "general" # fused + DECOUPLED = "decoupled" # rmsnorm_quant + projx_wqkvia + + +class RMSNormProjQAKVAKIWeightsConverter: + """Weights converter class.""" + + @staticmethod + def _swizzle_mma_16x32(mat_in: torch.Tensor) -> torch.Tensor: + assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32 + # PTX isa fig.88 + pre_shape = mat_in.shape[:-2] + mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4) + return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2) + + @staticmethod + def tilert_to_common( + tilert_wqkv_a: torch.Tensor, + tilert_wqkv_a_scales: torch.Tensor, + tilert_attn_norm_weight: torch.Tensor, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + """ + Convert tilert weights to common weights. + + Args: + tilert_wqkv_a: Tilert weight tensor. + tilert_wqkv_a_scales: Tilert weight scale tensor. + tilert_attn_norm_weight: Tilert attention norm weight tensor. + Returns: + tuple: Common weights. + """ + wq_a = tilert_wqkv_a[:1536] # 1536, 7168 + wkv_a = tilert_wqkv_a[1536 : 1536 + 576] # 576, 7168 + wk = tilert_wqkv_a[1536 + 576 :] # 128, 7168 + + wqkv_a_scales_0 = tilert_wqkv_a_scales[:128, :].reshape(16, 8, 64) + wqkv_a_scales_0 = wqkv_a_scales_0[:, 0, :].reshape(16, 64) + wqkv_a_scales_1 = tilert_wqkv_a_scales[128:129, :] # 1, 64 + wqkv_a_scales_2 = tilert_wqkv_a_scales[129:, :] # 1, 64 + wqkv_a_scales_swizzled = torch.cat( + [wqkv_a_scales_0, wqkv_a_scales_1, wqkv_a_scales_2], dim=0 + ) + wqkv_scales = torch.zeros( + (18, 56), dtype=torch.bfloat16, device=tilert_wqkv_a_scales.device + ) + + for i in range(64): + if ((i % 8) * 8 + i // 8) < 56: + wqkv_scales[:, ((i % 8) * 8 + i // 8)] = wqkv_a_scales_swizzled[:, i] + wq_a_scale = wqkv_scales[:12, :] # 12, 56 + wkv_a_scale = wqkv_scales[12:17, :] # 5, 56 + wk_scale = wqkv_scales[17:, :] # 1, 56 + + attn_norm_weight = tilert_attn_norm_weight + return wq_a, wq_a_scale, wkv_a, wkv_a_scale, wk, wk_scale, attn_norm_weight + + @staticmethod + def common_to_tilert( + wq_a: torch.Tensor, + wq_a_scale: torch.Tensor, + wkv_a: torch.Tensor, + wkv_a_scale: torch.Tensor, + wk: torch.Tensor, + wk_scale: torch.Tensor, + attn_norm_weight: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Convert common weights to tilert weights. + + Args: + wq_a: Common weight tensor. + wq_a_scale: Common weight scale tensor. + wkv_a: Common weight tensor. + wkv_a_scale: Common weight scale tensor. + wk: Common weight tensor. + wk_scale: Common weight scale tensor. + attn_norm_weight: Common attention norm weight tensor. + Returns: + tuple: Tilert weights. + """ + wqkv_a = torch.cat([wq_a, wkv_a, wk], dim=0) + wqkv_a_scales_raw = torch.cat([wq_a_scale, wkv_a_scale, wk_scale], dim=0) + + wqkv_a_scales = torch.zeros((18, 64), dtype=torch.bfloat16, device=wq_a_scale.device) + for i in range(64): + wqkv_a_scales[:, i] = wqkv_a_scales_raw[:, ((i % 8) * 8 + i // 8) % 56] + if ((i % 8) * 8 + i // 8) >= 56: + wqkv_a_scales[:, i] = 0.0 + wqkv_a_scales_0 = wqkv_a_scales[:16, :] + wqkv_a_scales_1 = wqkv_a_scales[16:17, :] + wqkv_a_scales_2 = wqkv_a_scales[17:, :] + + wqkv_a_scales_0 = wqkv_a_scales_0.reshape((16, 1, 64)).repeat(1, 8, 1).reshape(-1, 64) + wqkv_a_scales = torch.cat([wqkv_a_scales_0, wqkv_a_scales_1, wqkv_a_scales_2], dim=0) + assert wqkv_a_scales.shape == (130, 64) + return wqkv_a.contiguous(), wqkv_a_scales.contiguous(), attn_norm_weight.clone() + + @staticmethod + def common_to_tilert_fp8( + wq_a: torch.Tensor, + wq_a_scale: torch.Tensor, + wkv_a: torch.Tensor, + wkv_a_scale: torch.Tensor, + wk: torch.Tensor, + wk_scale: torch.Tensor, + attn_norm_weight: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert common weights to tilert weights. + + Args: + wq_a: Common weight tensor. + wq_a_scale: Common weight scale tensor. + wkv_a: Common weight tensor. + wkv_a_scale: Common weight scale tensor. + wk: Common weight tensor. + wk_scale: Common weight scale tensor. + attn_norm_weight: Common attention norm weight tensor. + Returns: + tuple: Tilert fp8 weights. + """ + wq_a_raw: torch.Tensor = wq_a.detach().clone() + wkv_a_raw: torch.Tensor = wkv_a.detach().clone() + wq_a_raw = torch.cat([wq_a_raw, wkv_a_raw[:512], wk, wkv_a_raw[512:]], dim=0) + + wq_a_raw = wq_a_raw.reshape(35, 64, 14, 512) + wq_a_raw = wq_a_raw.permute(0, 2, 1, 3) + + wq_a_raw = wq_a_raw.reshape(35, 14, 16, 4, 4, 128) + wq_a_copy = wq_a_raw.contiguous().clone() + wq_a_raw[:, :, 1::2, :, :, :64] = wq_a_copy[:, :, 1::2, :, :, 64:] + wq_a_raw[:, :, 1::2, :, :, 64:] = wq_a_copy[:, :, 1::2, :, :, :64] + wq_a_raw = wq_a_raw.reshape(35, 14, 16, 4, 4, 2, 64) + wq_a_copy = wq_a_raw.contiguous().clone() + wq_a_raw[:, :, :, 2:, :, :, :32] = wq_a_copy[:, :, :, 2:, :, :, 32:] + wq_a_raw[:, :, :, 2:, :, :, 32:] = wq_a_copy[:, :, :, 2:, :, :, :32] + wq_a_raw = wq_a_raw.reshape(35, 14, 16, 4, 4, 2, 2, 32) + wq_a_copy = wq_a_raw.contiguous().clone() + wq_a_raw[:, :, :, 1::2, :, :, :, :16] = wq_a_copy[:, :, :, 1::2, :, :, :, 16:] + wq_a_raw[:, :, :, 1::2, :, :, :, 16:] = wq_a_copy[:, :, :, 1::2, :, :, :, :16] + + wq_a_raw = wq_a_raw.reshape(35, 14, 16, 4, 4, 128) + wq_a_raw = wq_a_raw.permute(0, 1, 4, 2, 3, 5).reshape(35, 14, -1).contiguous() + wq_a_raw = wq_a_raw.reshape(35, 14, -1).contiguous() + + wq_s_raw: torch.Tensor = wq_a_scale.detach().clone() + wkv_s_raw: torch.Tensor = wkv_a_scale.detach().clone() + wq_s_raw = torch.cat([wq_s_raw, wkv_s_raw[:4], wk_scale, wkv_s_raw[4:]], dim=0) + wq_s_raw = wq_s_raw.reshape(18, 1, 14, 4).repeat(1, 2, 1, 1).reshape(36, 1, 14, 4) + wq_s_raw = wq_s_raw[:35].reshape(35, 14, -1).contiguous() + wq_s_raw = wq_s_raw.view(torch.float8_e4m3fn) + wq_as_raw = torch.cat([wq_a_raw, wq_s_raw], dim=-1) + + return wq_as_raw.contiguous(), attn_norm_weight.clone() + + @staticmethod + def common_to_tilert_native_bf16( + wq_a: torch.Tensor, + wq_a_scale: torch.Tensor, + wkv_a: torch.Tensor, + wkv_a_scale: torch.Tensor, + wk: torch.Tensor, + wk_scale: torch.Tensor, + attn_norm_weight: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert common weights to weights for tilert native bf16 op. + + Args: + wq_a: Common weight tensor. + wq_a_scale: Common weight scale tensor. + wkv_a: Common weight tensor. + wkv_a_scale: Common weight scale tensor. + wk: Common weight tensor. + wk_scale: Common weight scale tensor. + attn_norm_weight: Common attention norm weight tensor. + Returns: + tuple: Tilert weights for native bf16 op. + """ + wq_a_scale = wq_a_scale.reshape((12, 56, 1)).repeat(1, 1, 128).reshape((12, 1, 7168)) + wq_a_scale = wq_a_scale.repeat(1, 128, 1).reshape((1536, 7168)) + wkv_a_scale = wkv_a_scale.reshape((5, 56, 1)).repeat(1, 1, 128).reshape((5, 1, 7168)) + wkv_a_scale = wkv_a_scale.repeat(1, 128, 1).reshape((-1, 7168)) + wkv_a_scale = wkv_a_scale[:576] + wk_scale = wk_scale.reshape((1, 56, 1)).repeat(1, 1, 128).reshape((1, 1, 7168)) + wk_scale = wk_scale.repeat(1, 128, 1).reshape((128, 7168)) + wq_a = wq_a.reshape((1536, 7168)).float() * wq_a_scale.float() + wkv_a = wkv_a.reshape((576, 7168)).float() * wkv_a_scale.float() + wk = wk.reshape((128, 7168)).float() * wk_scale.float() + weights = torch.cat([wq_a, wkv_a, wk], dim=0) + assert weights.shape == (1536 + 576 + 128, 7168) + return weights.to(torch.bfloat16).contiguous(), attn_norm_weight.clone() + + @staticmethod + def common_to_tilert_native_bf16_warp_gemv( + wq_a: torch.Tensor, + wq_a_scale: torch.Tensor, + wkv_a: torch.Tensor, + wkv_a_scale: torch.Tensor, + wk: torch.Tensor, + wk_scale: torch.Tensor, + attn_norm_weight: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert common weights to weights for tilert native bf16 warp gemv op. + + Args: + wq_a: Common weight tensor. + wq_a_scale: Common weight scale tensor. + wkv_a: Common weight tensor. + wkv_a_scale: Common weight scale tensor. + wk: Common weight tensor. + wk_scale: Common weight scale tensor. + attn_norm_weight: Common attention norm weight tensor. + Returns: + tuple: Tilert weights for native bf16 warp gemv op. + """ + wq_a_scale = wq_a_scale.reshape((12, 56, 1)).repeat(1, 1, 128).reshape((12, 1, 7168)) + wq_a_scale = wq_a_scale.repeat(1, 128, 1).reshape((1536, 7168)) + wkv_a_scale = wkv_a_scale.reshape((5, 56, 1)).repeat(1, 1, 128).reshape((5, 1, 7168)) + wkv_a_scale = wkv_a_scale.repeat(1, 128, 1).reshape((-1, 7168)) + wkv_a_scale = wkv_a_scale[:576] + wk_scale = wk_scale.reshape((1, 56, 1)).repeat(1, 1, 128).reshape((1, 1, 7168)) + wk_scale = wk_scale.repeat(1, 128, 1).reshape((128, 7168)) + wq_a = wq_a.reshape((1536, 7168)).float() * wq_a_scale.float() + wkv_a = wkv_a.reshape((576, 7168)).float() * wkv_a_scale.float() + wk = wk.reshape((128, 7168)).float() * wk_scale.float() + # concatenate the weights + weights = torch.cat([wq_a, wkv_a, wk], dim=0) + assert weights.shape == (1536 + 576 + 128, 7168) + + weights = weights.reshape(140, 16, 7, 1024) + weights = weights.transpose(1, 2) # 140, 7, 16, 1024 + return weights.to(torch.bfloat16).contiguous(), attn_norm_weight.clone() + + @staticmethod + def common_to_tilert_dequant_bf16( + wq_a: torch.Tensor, + wq_a_scale: torch.Tensor, + wkv_a: torch.Tensor, + wkv_a_scale: torch.Tensor, + wk: torch.Tensor, + wk_scale: torch.Tensor, + attn_norm_weight: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert common weights to weights for tilert dequant bf16 op. + + Args: + wq_a: Common weight tensor. + wq_a_scale: Common weight scale tensor. + wkv_a: Common weight tensor. + wkv_a_scale: Common weight scale tensor. + wk: Common weight tensor. + wk_scale: Common weight scale tensor. + attn_norm_weight: Common attention norm weight tensor. + Returns: + tuple: Tilert weights for dequant bf16 op. + """ + wq_a = wq_a.reshape((384, 4, 7168)) + wkv_a = wkv_a.reshape((144, 4, 7168)) + wk = wk.reshape((32, 4, 7168)) + wqkv = torch.cat([wq_a, wkv_a, wk], dim=0).reshape(140, 4, 4 * 7168) + + wq_a_scale = wq_a_scale.reshape((12, 1, 56)).repeat(1, 32, 1).reshape((384, 1, 56)) + wkv_a_scale = wkv_a_scale.reshape((5, 1, 56)).repeat(1, 32, 1).reshape((160, 1, 56))[:144] + wk_scale = wk_scale.reshape((1, 1, 56)).repeat(1, 32, 1).reshape((32, 1, 56)) + wqkv_scales = torch.cat([wq_a_scale, wkv_a_scale, wk_scale], dim=0).reshape(140, 4, 56) + wqkv_scales_swizzled = torch.zeros(140, 4, 64, dtype=torch.bfloat16, device=wq_a.device) + # swizzle + for i in range(64): + wqkv_scales_swizzled[..., i] = wqkv_scales[..., ((i % 8) * 8 + i // 8) % 56] + weights = torch.zeros( + 140, 4, 4 * 7168 + 64 * 2, dtype=torch.float8_e4m3fn, device=wq_a.device + ) + weights_part = weights[:, :, : 4 * 7168] + scales_part = weights[:, :, 4 * 7168 :] + weights_part.copy_(wqkv) + scales_part.copy_(wqkv_scales_swizzled.view(dtype=torch.float8_e4m3fn)) + return weights.contiguous(), attn_norm_weight.clone() + + @staticmethod + def common_to_tilert_fp8_mma( + wq_a: torch.Tensor, + wq_a_scale: torch.Tensor, + wkv_a: torch.Tensor, + wkv_a_scale: torch.Tensor, + wk: torch.Tensor, + wk_scale: torch.Tensor, + rmsnorm_gamma: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert common weights to weights for tilert fp8 mma op. + + Args: + wq_a: Common weight tensor. + wq_a_scale: Common weight scale tensor. + wkv_a: Common weight tensor. + wkv_a_scale: Common weight scale tensor. + wk: Common weight tensor. + wk_scale: Common weight scale tensor. + rmsnorm_gamma: Common rmsnorm gamma tensor. + Returns: + tuple: Tilert weights for fp8 mma op. + """ + assert wq_a.shape == (1536, 7168) + assert wq_a_scale.shape == (12, 56) + assert wkv_a.shape == (576, 7168) + assert wkv_a_scale.shape == (5, 56) + assert wk.shape == (128, 7168) + assert wk_scale.shape == (1, 56) + wq_a = wq_a.reshape(96, 16, 7168) + wq_a_scale = wq_a_scale.reshape(12, 1, 56).repeat(1, 8, 1).reshape(96, 56) + wkv_a = wkv_a.reshape(36, 16, 7168) + wkv_a_scale = wkv_a_scale.reshape(5, 1, 56).repeat(1, 8, 1).reshape(40, 56) + wkv_a_scale = wkv_a_scale[:36] + + wk = wk.reshape(8, 16, 7168) + wk_scale = wk_scale.reshape(1, 1, 56).repeat(1, 8, 1).reshape(8, 56) + wqkvia = torch.cat([wq_a, wkv_a, wk], dim=0) # 140, 7168 + wqkvia_scale = torch.cat([wq_a_scale, wkv_a_scale, wk_scale], dim=0) # 140, 56 + + wqkvia_0 = wqkvia[..., :2048] + wqkvia_0_scale = wqkvia_scale[..., :16].contiguous().view(torch.float8_e4m3fn) + wqkvia_1 = wqkvia[..., 2048:4096] + wqkvia_1_scale = wqkvia_scale[..., 16:32].contiguous().view(torch.float8_e4m3fn) + wqkvia_2 = wqkvia[..., 4096:6144] + wqkvia_2_scale = wqkvia_scale[..., 32:48].contiguous().view(torch.float8_e4m3fn) + wqkvia_3 = wqkvia[..., 6144:7168] + wqkvia_3_scale = wqkvia_scale[..., 48:56].contiguous().view(torch.float8_e4m3fn) + + wqkvia_0 = wqkvia_0.reshape(140, 16, 64, 32).transpose(1, 2) + wqkvia_0 = RMSNormProjQAKVAKIWeightsConverter._swizzle_mma_16x32(wqkvia_0) + wqkvia_0 = wqkvia_0.reshape(140, 16 * 2048) + + wqkvia_1 = wqkvia_1.reshape(140, 16, 64, 32).transpose(1, 2) + wqkvia_1 = RMSNormProjQAKVAKIWeightsConverter._swizzle_mma_16x32(wqkvia_1) + wqkvia_1 = wqkvia_1.reshape(140, 16 * 2048) + + wqkvia_2 = wqkvia_2.reshape(140, 16, 64, 32).transpose(1, 2) + wqkvia_2 = RMSNormProjQAKVAKIWeightsConverter._swizzle_mma_16x32(wqkvia_2) + wqkvia_2 = wqkvia_2.reshape(140, 16 * 2048) + + wqkvia_3 = wqkvia_3.reshape(140, 16, 32, 32).transpose(1, 2) + wqkvia_3 = RMSNormProjQAKVAKIWeightsConverter._swizzle_mma_16x32(wqkvia_3) + wqkvia_3 = wqkvia_3.reshape(140, 16 * 1024) + padding_scale0 = torch.zeros((140, 48), dtype=torch.bfloat16, device=wq_a.device).view( + torch.float8_e4m3fn + ) + padding_scale1 = torch.zeros((140, 48), dtype=torch.bfloat16, device=wq_a.device).view( + torch.float8_e4m3fn + ) + padding_scale2 = torch.zeros((140, 48), dtype=torch.bfloat16, device=wq_a.device).view( + torch.float8_e4m3fn + ) + padding_scale3 = torch.zeros((140, 56), dtype=torch.bfloat16, device=wq_a.device).view( + torch.float8_e4m3fn + ) + wqkvia = torch.cat( + [ + wqkvia_0, + wqkvia_0_scale, + padding_scale0, + wqkvia_1, + wqkvia_1_scale, + padding_scale1, + wqkvia_2, + wqkvia_2_scale, + padding_scale2, + wqkvia_3, + wqkvia_3_scale, + padding_scale3, + ], + dim=1, + ) + + return wqkvia.contiguous(), rmsnorm_gamma.contiguous() + + +class RMSNormProjxWqkviaWeightsConverter(TilertWeightsConverter): + """RMSNormProjxWqkvia weights converter""" + + @staticmethod + def _swizzle_qmma_16x32(mat_in: torch.Tensor) -> torch.Tensor: + assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32 + assert mat_in.dtype == torch.float8_e4m3fn + # PTX isa fig.88 + pre_shape = mat_in.shape[:-2] + mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4) + return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2) + + def convert_to_general(self, weights: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert the weights to general format. + + Args: + weights: List of weights. + + Returns: + Tuple of weights. + """ + # Specialized for DS v3.2 model + args = self.model_args + assert ( + args.arch_name == "deepseek_v3_2" + ), f"arch_name must be deepseek_v3_2, but got {args.arch_name}" + with torch.inference_mode(): + x_rmsnorm_gamma, wq_a, wq_a_scale, wkv_a, wkv_a_scale, wk, wk_scale = weights + q_lora_rank_scale_dim = args.q_lora_rank // args.block_size + kv_lora_rank_scale_dim = args.kv_lora_rank // args.block_size + 1 + x_scale_dim = args.dim // args.block_size + + wq_a_scale = ( + wq_a_scale.reshape((q_lora_rank_scale_dim, x_scale_dim, 1)) + .repeat(1, 1, args.block_size) + .reshape((q_lora_rank_scale_dim, 1, args.dim)) + ) + wq_a_scale = wq_a_scale.repeat(1, args.block_size, 1).reshape( + (args.q_lora_rank, args.dim) + ) + wkv_a_scale = ( + wkv_a_scale.reshape((kv_lora_rank_scale_dim, x_scale_dim, 1)) + .repeat(1, 1, args.block_size) + .reshape((kv_lora_rank_scale_dim, 1, args.dim)) + ) + wkv_a_scale = wkv_a_scale.repeat(1, args.block_size, 1).reshape((-1, args.dim)) + wkv_a_scale = wkv_a_scale[: args.kv_lora_rank + args.qk_rope_head_dim] + wk_scale = ( + wk_scale.reshape((1, x_scale_dim, 1)) + .repeat(1, 1, args.block_size) + .reshape((1, 1, args.dim)) + ) + wk_scale = wk_scale.repeat(1, args.block_size, 1).reshape( + (args.index_head_dim, args.dim) + ) + wq_a = wq_a.reshape((args.q_lora_rank, args.dim)).float() * wq_a_scale.float() + wkv_a = ( + wkv_a.reshape((args.kv_lora_rank + args.qk_rope_head_dim, args.dim)).float() + * wkv_a_scale.float() + ) + wk = wk.reshape((args.index_head_dim, args.dim)).float() * wk_scale.float() + # concatenate the weights + weights_tensor: torch.Tensor = torch.cat([wq_a, wkv_a, wk], dim=0) + assert weights_tensor.shape == ( + args.q_lora_rank + args.kv_lora_rank + args.qk_rope_head_dim + args.index_head_dim, + args.dim, + ) + # hard-coded scheduling: reshape to 140, 16, 7, 1024 + weights_tensor = weights_tensor.reshape(140, 16, 7, 1024) + weights_tensor = weights_tensor.transpose(1, 2) # 140, 7, 16, 1024 + return x_rmsnorm_gamma, weights_tensor.to(torch.bfloat16).contiguous() + + def convert_to_decoupled( + self, weights: list[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert the weights to decoupled format. + + Args: + weights: List of weights. + + Returns: + Tuple of weights. + """ + arch_name = self.model_args.arch_name + wqkvia_and_scales = None + with torch.inference_mode(): + x_rmsnorm_gamma, wq_a, wq_a_scale, wkv_a, wkv_a_scale, wk, wk_scale = weights + # Ensure the scales are in bfloat16 + if arch_name == "deepseek_v3_2": # DS v3.2 + # Ensure the scales are in bfloat16 for DS v3.2 + wq_a_scale = wq_a_scale.to(torch.bfloat16) + wkv_a_scale = wkv_a_scale.to(torch.bfloat16) + wk_scale = wk_scale.to(torch.bfloat16) + assert wq_a.shape == (1536, 7168) + assert wq_a_scale.shape == (12, 56) + assert wkv_a.shape == (576, 7168) + assert wkv_a_scale.shape == (5, 56) + assert wk.shape == (128, 7168) + assert wk_scale.shape == (1, 56) + wq_a = wq_a.reshape(96, 16, 7168) + wq_a_scale = wq_a_scale.reshape(12, 1, 56).repeat(1, 8, 1).reshape(96, 56) + wkv_a = wkv_a.reshape(36, 16, 7168) + wkv_a_scale = wkv_a_scale.reshape(5, 1, 56).repeat(1, 8, 1).reshape(40, 56) + wkv_a_scale = wkv_a_scale[:36] + + wk = wk.reshape(8, 16, 7168) + wk_scale = wk_scale.reshape(1, 1, 56).repeat(1, 8, 1).reshape(8, 56) + wqkvia = torch.cat([wq_a, wkv_a, wk], dim=0) # 140, 7168 + wqkvia_scale = torch.cat([wq_a_scale, wkv_a_scale, wk_scale], dim=0) # 140, 56 + + wqkvia_0 = wqkvia[..., :2048] + wqkvia_0_scale = wqkvia_scale[..., :16].contiguous().view(torch.float8_e4m3fn) + wqkvia_1 = wqkvia[..., 2048:4096] + wqkvia_1_scale = wqkvia_scale[..., 16:32].contiguous().view(torch.float8_e4m3fn) + wqkvia_2 = wqkvia[..., 4096:6144] + wqkvia_2_scale = wqkvia_scale[..., 32:48].contiguous().view(torch.float8_e4m3fn) + wqkvia_3 = wqkvia[..., 6144:7168] + wqkvia_3_scale = wqkvia_scale[..., 48:56].contiguous().view(torch.float8_e4m3fn) + + wqkvia_0 = wqkvia_0.reshape(140, 16, 64, 32).transpose(1, 2) + wqkvia_0 = self._swizzle_qmma_16x32(wqkvia_0) + wqkvia_0 = wqkvia_0.reshape(140, 16 * 2048) + + wqkvia_1 = wqkvia_1.reshape(140, 16, 64, 32).transpose(1, 2) + wqkvia_1 = self._swizzle_qmma_16x32(wqkvia_1) + wqkvia_1 = wqkvia_1.reshape(140, 16 * 2048) + + wqkvia_2 = wqkvia_2.reshape(140, 16, 64, 32).transpose(1, 2) + wqkvia_2 = self._swizzle_qmma_16x32(wqkvia_2) + wqkvia_2 = wqkvia_2.reshape(140, 16 * 2048) + + wqkvia_3 = wqkvia_3.reshape(140, 16, 32, 32).transpose(1, 2) + wqkvia_3 = self._swizzle_qmma_16x32(wqkvia_3) + wqkvia_3 = wqkvia_3.reshape(140, 16 * 1024) + padding_scale0 = torch.zeros( + (140, 48), dtype=torch.bfloat16, device=wq_a.device + ).view(torch.float8_e4m3fn) + padding_scale1 = torch.zeros( + (140, 48), dtype=torch.bfloat16, device=wq_a.device + ).view(torch.float8_e4m3fn) + padding_scale2 = torch.zeros( + (140, 48), dtype=torch.bfloat16, device=wq_a.device + ).view(torch.float8_e4m3fn) + padding_scale3 = torch.zeros( + (140, 56), dtype=torch.bfloat16, device=wq_a.device + ).view(torch.float8_e4m3fn) + wqkvia_and_scales = torch.cat( + [ + wqkvia_0, + wqkvia_0_scale, + padding_scale0, + wqkvia_1, + wqkvia_1_scale, + padding_scale1, + wqkvia_2, + wqkvia_2_scale, + padding_scale2, + wqkvia_3, + wqkvia_3_scale, + padding_scale3, + ], + dim=1, + ) + elif arch_name == "glm_5": # GLM5 + # Ensure the scales are in float32 for DS v3.2 + if wq_a_scale.dtype != torch.float32: + # TODO: remove this after the source weights are converted to float32 + print( + "Warning: RMSNormProjxWqkviaWeightsConverter: " + + "wq_a_scale is not in float32, converting to float32." + ) + wq_a_scale = wq_a_scale.to(torch.float32) + wkv_a_scale = wkv_a_scale.to(torch.float32) + wk_scale = wk_scale.to(torch.float32) + # (2048 + 576 + 128, 6144) + wqkvia = torch.cat([wq_a, wkv_a, wk], dim=0).reshape(86, 32, 6144) + # (16+5+1, 48) + wq_a_scale = wq_a_scale.reshape((16, 1, 48)).repeat(1, 4, 1).reshape(64, 48) + wkv_a_scale = wkv_a_scale.reshape((5, 1, 48)).repeat(1, 4, 1).reshape(20, 48)[:18] + wk_scale = wk_scale.reshape((1, 1, 48)).repeat(1, 4, 1).reshape(4, 48) + wqkvia_scales = torch.cat([wq_a_scale, wkv_a_scale, wk_scale], dim=0) # (86, 48) + wqkvia = wqkvia.reshape(86, 32, 6, 1024).transpose(1, 2).reshape(86, 6, 2, 16, 1024) + wqkvia = wqkvia.reshape(86, 6, 2, 16, 32, 32).transpose(3, 4) + wqkvia = self._swizzle_qmma_16x32(wqkvia).reshape(86, 6, 32 * 1024) + wqkvia_scales = wqkvia_scales.reshape(86, 6, 8).view(torch.float8_e4m3fn) + wqkvia_padding = torch.zeros( + (86, 6, 128 - wqkvia_scales.shape[-1]), + dtype=torch.float8_e4m3fn, + device=wq_a.device, + ) + wqkvia_and_scales = torch.cat([wqkvia, wqkvia_scales, wqkvia_padding], dim=-1) + else: + raise ValueError(f"Unsupported architecture: {arch_name}") + assert wqkvia_and_scales is not None + return x_rmsnorm_gamma.float(), wqkvia_and_scales.contiguous() + + +@dataclass +class RMSNormProjxWqkviaRefWeightsAlias: + """Reference weights alias for RMSNormProjxWqkvia.""" + + x_rmsnorm_gamma = "input_layernorm.weight" + q_a_weights = "self_attn.q_a_proj.weight" + q_a_scales = "self_attn.q_a_proj.weight_scale_inv" + kv_a_with_mqa_weights = "self_attn.kv_a_proj_with_mqa.weight" + kv_a_with_mqa_scales = "self_attn.kv_a_proj_with_mqa.weight_scale_inv" + wk_weights = "self_attn.indexer.wk.weight" + wk_scales = "self_attn.indexer.wk.weight_scale_inv" + + @property + def ref_tensor_alias(self) -> list[str]: + return [ + self.x_rmsnorm_gamma, + self.q_a_weights, + self.q_a_scales, + self.kv_a_with_mqa_weights, + self.kv_a_with_mqa_scales, + self.wk_weights, + self.wk_scales, + ] + + def __call__(self) -> list[str]: + return self.ref_tensor_alias + + +@dataclass +class RMSNormProjxWqkviaTilertWeightsAlias: + """TileRT weights alias for RMSNormProjxWqkvia.""" + + x_rmsnorm_gamma = "x_rmsnorm_gamma" + q_a_weights = "q_a_weights" + q_a_scales = "q_a_scales" + kv_a_with_mqa_weights = "kv_a_with_mqa_weights" + kv_a_with_mqa_scales = "kv_a_with_mqa_scales" + wk_weights = "wk_weights" + wk_scales = "wk_scales" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [ + self.x_rmsnorm_gamma, + self.q_a_weights, + self.q_a_scales, + self.kv_a_with_mqa_weights, + self.kv_a_with_mqa_scales, + self.wk_weights, + self.wk_scales, + ] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class RMSNormProjxWqkvia(TileRTModule): + """RMSNormProjxWqkvia module""" + + def __init__( + self, + model_args: ModelArgs, + num_devices: int, + device_id: int, + ref_weights_alias: RMSNormProjxWqkviaRefWeightsAlias | None = None, + algorithm: RMSNormProjxWqkviaAlgorithm = RMSNormProjxWqkviaAlgorithm.GENERAL, + ): + super().__init__( + self.__class__.__name__, + model_args=model_args, + num_devices=num_devices, + device_id=device_id, + ) + + self.tilert_weights_alias = RMSNormProjxWqkviaTilertWeightsAlias() + self.ref_weights_alias = ( + ref_weights_alias + if ref_weights_alias is not None + else RMSNormProjxWqkviaRefWeightsAlias() + ) + + self.arch_name = self.model_args.arch_name + self.dim = self.model_args.dim + self.q_lora_rank = self.model_args.q_lora_rank + self.kv_lora_rank = self.model_args.kv_lora_rank + self.qk_rope_head_dim = self.model_args.qk_rope_head_dim + self.idx_head_dim = self.model_args.index_head_dim + self.block_size = self.model_args.block_size + self.eps = self.model_args.eps + self.algorithm: RMSNormProjxWqkviaAlgorithm = algorithm + + # reference weights + self.ref_norm_gamma: torch.Tensor | None = None + self.ref_wq_a: torch.Tensor | None = None + self.ref_wkv_a: torch.Tensor | None = None + self.ref_wk: torch.Tensor | None = None + + # tilert weights + self.tilert_norm_gamma: torch.Tensor | None = None + self.tilert_wqkv_a: torch.Tensor | None = None + # Legacy scale tensor for compatibility, to be removed in the future + self.tilert_wqkv_a_scales = torch.zeros((130, 64), dtype=torch.bfloat16) + + # tilert vars + self.x_rmsnorm_out: torch.Tensor | None = None + self.q_out: torch.Tensor | None = None + self.kv_out: torch.Tensor | None = None + self.ki_out: torch.Tensor | None = None + self.x_rmsnorm_quant_out: torch.Tensor | None = None + self.x_rmsnorm_quant_scale_out: torch.Tensor | None = None + + self.profile_logs: torch.Tensor | None = None + self.is_init = False + + # tilert_funcs + self.rmsnorm_proj_func: Callable | None = None + self.rmsnorm_func: Callable | None = None + self.proj_func: Callable | None = None + + if self.arch_name == "deepseek_v3_2": + self.rmsnorm_proj_func = rmsnorm_projx_wqkvia + self.rmsnorm_func = rmsnorm_quant + self.proj_func = projx_wqkvia + elif self.arch_name == "glm_5": + # Lazy import to avoid circular import + self.rmsnorm_proj_func = None + self.rmsnorm_func = rmsnorm_quant + self.proj_func = projx_wqkvia + else: + raise ValueError(f"Unsupported architecture: {self.arch_name}") + + # tilert tensor aliases (3 output weight names for get_weights_list) + self.tilert_tensor_alias: list[str] = [ + "x_rmsnorm_gamma", + "qkv_wa_weights", + "qkv_wa_scales", + ] + + def get_weights_list(self) -> list[torch.Tensor]: + """ + Get the weights list. + + Returns: + List of weights. + """ + assert self.algorithm is not None, "Algorithm is not set" + if self.algorithm == RMSNormProjxWqkviaAlgorithm.GENERAL: + return [self.tilert_norm_gamma, self.tilert_wqkv_a, self.tilert_wqkv_a_scales] + return [self.tilert_norm_gamma, self.tilert_wqkv_a] + + def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Device sharding. + + Args: + input_layernorm_weight: Input layernorm weight. + q_a_proj_weight: Q A proj weight. + q_a_proj_weight_scale: Q A proj weight scale. + kv_a_proj_weight: KV A proj weight. + kv_a_proj_weight_scale: KV A proj weight scale. + indexer_wk_weight: Indexer WK weight. + indexer_wk_weight_scale: Indexer WK weight scale. + + Returns: + Tuple of weights. + """ + # repeat n times for device sharding + # Using float to support both bfloat16 and float + input_layernorm_weight = ( + weights_map[self.ref_weights_alias.x_rmsnorm_gamma][None, ...] + .float() + .repeat(self.num_devices, 1) + ) + q_a_proj_weight = weights_map[self.ref_weights_alias.q_a_weights][None, ...].repeat( + self.num_devices, 1, 1 + ) + q_a_proj_weight_scale = weights_map[self.ref_weights_alias.q_a_scales][None, ...].repeat( + self.num_devices, 1, 1 + ) + kv_a_proj_weight = weights_map[self.ref_weights_alias.kv_a_with_mqa_weights][ + None, ... + ].repeat(self.num_devices, 1, 1) + kv_a_proj_weight_scale = weights_map[self.ref_weights_alias.kv_a_with_mqa_scales][ + None, ... + ].repeat(self.num_devices, 1, 1) + indexer_wk_weight = weights_map[self.ref_weights_alias.wk_weights][None, ...].repeat( + self.num_devices, 1, 1 + ) + indexer_wk_weight_scale = weights_map[self.ref_weights_alias.wk_scales][None, ...].repeat( + self.num_devices, 1, 1 + ) + return { + self.tilert_weights_alias.x_rmsnorm_gamma: input_layernorm_weight, + self.tilert_weights_alias.q_a_weights: q_a_proj_weight, + self.tilert_weights_alias.q_a_scales: q_a_proj_weight_scale, + self.tilert_weights_alias.kv_a_with_mqa_weights: kv_a_proj_weight, + self.tilert_weights_alias.kv_a_with_mqa_scales: kv_a_proj_weight_scale, + self.tilert_weights_alias.wk_weights: indexer_wk_weight, + self.tilert_weights_alias.wk_scales: indexer_wk_weight_scale, + } + + def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """ + Initialize the reference weights. + + Args: + state_dict: State dictionary. + """ + self.ref_norm_gamma = state_dict[self.ref_weights_alias()[0]] + self.ref_wq_a = weight_dequant( + state_dict[self.ref_weights_alias()[1]], state_dict[self.ref_weights_alias()[2]] + ) + self.ref_wkv_a = weight_dequant( + state_dict[self.ref_weights_alias()[3]], state_dict[self.ref_weights_alias()[4]] + ) + self.ref_wk = weight_dequant( + state_dict[self.ref_weights_alias()[5]], state_dict[self.ref_weights_alias()[6]] + ) + + assert self.ref_norm_gamma is not None + assert self.ref_wq_a is not None + assert self.ref_wkv_a is not None + assert self.ref_wk is not None + + assert ( + self.ref_norm_gamma.shape[-1] == self.dim + ), f"norm_gamma shape must be {self.dim}, but got {self.ref_norm_gamma.shape[-1]}" + assert self.ref_wq_a.shape[-2] == self.q_lora_rank, ( + f"wq_a shape must be {self.q_lora_rank}, " + f"but got {self.ref_wq_a.shape[-2]}" + ) + assert ( + self.ref_wq_a.shape[-1] == self.dim + ), f"wq_a shape must be {self.dim}, but got {self.ref_wq_a.shape[-1]}" + assert self.ref_wkv_a.shape[-2] == self.kv_lora_rank + self.qk_rope_head_dim, ( + f"wkv_a shape must be {self.kv_lora_rank + self.qk_rope_head_dim}, " + + f"but got {self.ref_wkv_a.shape[-2]}" + ) + assert ( + self.ref_wkv_a.shape[-1] == self.dim + ), f"wkv_a shape must be {self.dim}, but got {self.ref_wkv_a.shape[-1]}" + assert ( + self.ref_wk.shape[-2] == self.idx_head_dim + ), f"wk shape must be {self.idx_head_dim}, but got {self.ref_wk.shape[-2]}" + assert ( + self.ref_wk.shape[-1] == self.dim + ), f"wk shape must be {self.dim}, but got {self.ref_wk.shape[-1]}" + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """ + Initialize the tilert weights. + + Args: + state_dict: State dictionary. + """ + assert self.algorithm is not None, "Algorithm is not set" + self.tilert_norm_gamma, self.tilert_wqkv_a = RMSNormProjxWqkviaWeightsConverter( + self.model_args, self.num_devices + ).dispatch(self.algorithm, [state_dict[alias] for alias in self.tilert_weights_alias()]) + + def init_tilert_vars(self, batch_size: int, seq_len: int) -> None: + """ + Initialize the tilert variables. + + Args: + batch_size: Batch size. + seq_len: Sequence length. + """ + self.q_out = torch.zeros((batch_size, seq_len, self.q_lora_rank), dtype=torch.bfloat16) + self.kv_out = torch.zeros((batch_size, seq_len, self.kv_lora_rank), dtype=torch.bfloat16) + self.ki_out = torch.zeros((batch_size, seq_len, self.idx_head_dim), dtype=torch.bfloat16) + self.x_rmsnorm_out = torch.zeros((batch_size, seq_len, self.dim), dtype=torch.bfloat16) + if self.algorithm == RMSNormProjxWqkviaAlgorithm.DECOUPLED: + self.x_rmsnorm_quant_out = torch.zeros( + (batch_size, seq_len, self.dim), dtype=torch.float8_e4m3fn + ) + self.x_rmsnorm_quant_scale_out = torch.zeros( + (batch_size, seq_len, self.dim // self.block_size), dtype=torch.float32 + ) + self.profile_logs = get_profile_log_tensor() + self.is_init = True + + def init_random_weights(self) -> None: + """ + Initialize the random weights. + + Returns: + None + """ + q_scale_dim = self.q_lora_rank // self.block_size + kv_scale_dim = (self.kv_lora_rank + self.qk_rope_head_dim) // self.block_size + 1 + wk_scale_dim = self.idx_head_dim // self.block_size + dim_scale_dim = self.dim // self.block_size + scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16 + + tensor_list = [ + torch.randn(self.dim, dtype=torch.float32), + torch.randn(self.q_lora_rank, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn), + torch.randn(q_scale_dim, dim_scale_dim, dtype=scale_dtype), + torch.randn( + self.kv_lora_rank + self.qk_rope_head_dim, self.dim, dtype=torch.bfloat16 + ).to(torch.float8_e4m3fn), + torch.randn(kv_scale_dim, dim_scale_dim, dtype=scale_dtype), + torch.randn(self.idx_head_dim, self.dim, dtype=torch.bfloat16).to(torch.float8_e4m3fn), + torch.randn(wk_scale_dim, dim_scale_dim, dtype=scale_dtype), + ] + ref_state_dict = dict(zip(self.ref_weights_alias(), tensor_list)) + self.init_reference_weights(ref_state_dict) + self.init_tilert_weights( + {_k: _v[self.device_id] for _k, _v in self.device_sharding(ref_state_dict).items()} + ) + + def golden_forward( + self, + x: torch.Tensor, + pe_cache: torch.Tensor, + start_pos: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + assert self.ref_norm_gamma is not None + assert self.ref_wq_a is not None + assert self.ref_wkv_a is not None + assert self.ref_wk is not None + + x_rmsnorm_out = torch.nn.functional.rms_norm( + x.float(), [x.size(-1)], self.ref_norm_gamma, self.eps + ) + + q_out = torch.matmul(x_rmsnorm_out.float(), self.ref_wq_a.transpose(0, 1).float()) + kv_out = torch.matmul(x_rmsnorm_out.float(), self.ref_wkv_a.transpose(0, 1).float()) + kv_out, k_pe = torch.split(kv_out, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + bsz = k_pe.shape[0] + seq_len = k_pe.shape[1] + pe_cache[:bsz, start_pos : start_pos + seq_len].copy_(k_pe.to(torch.bfloat16)) + ki_out = torch.matmul(x_rmsnorm_out.float(), self.ref_wk.transpose(0, 1).float()) + return ( + x_rmsnorm_out.to(torch.bfloat16), + q_out.to(torch.bfloat16), + kv_out.to(torch.bfloat16), + ki_out.to(torch.bfloat16), + ) + + def tilert_forward( + self, + x: torch.Tensor, + pe_cache: torch.Tensor, + start_pos: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if self.algorithm == RMSNormProjxWqkviaAlgorithm.GENERAL: + assert self.rmsnorm_proj_func is not None + self.rmsnorm_proj_func( + x.to(torch.bfloat16), + self.tilert_wqkv_a, + self.tilert_wqkv_a_scales, + self.tilert_norm_gamma, + torch.tensor([start_pos], dtype=torch.int32, device=x.device), + self.q_out, + self.kv_out, + pe_cache, + self.ki_out, + self.x_rmsnorm_out, + self.profile_logs, + ) + elif self.algorithm == RMSNormProjxWqkviaAlgorithm.DECOUPLED: + assert self.rmsnorm_func is not None + assert self.proj_func is not None + self.rmsnorm_func( + x.to(torch.bfloat16), + self.tilert_norm_gamma, + self.x_rmsnorm_out, + self.x_rmsnorm_quant_out, + self.x_rmsnorm_quant_scale_out, + self.profile_logs, + ) + self.proj_func( + self.x_rmsnorm_quant_out, + self.x_rmsnorm_quant_scale_out, + self.tilert_wqkv_a, + torch.tensor([start_pos], dtype=torch.int32, device=x.device), + self.q_out, + self.kv_out, + pe_cache, + self.ki_out, + self.profile_logs, + ) + else: + raise ValueError(f"Unsupported algorithm: {self.algorithm}") + + if self.flag_enable_profiling_log: + parse_profile_log_tensor( + self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)] + ) + return self.x_rmsnorm_out, self.q_out, self.kv_out, self.ki_out + + def __call__( + self, + x: torch.Tensor, + pe_cache: torch.Tensor, + start_pos: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return self.golden_forward(x, pe_cache, start_pos) diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_quant.py b/python/models/deepseek_v3_2/ops/rmsnorm_quant.py new file mode 100644 index 0000000..770db02 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/rmsnorm_quant.py @@ -0,0 +1,73 @@ +"""RMSNormQuant operation module. + +Unified for deepseek_v3_2 (dim=7168) and glm_5 (dim=6144). +Dispatches by hidden_in.shape[-1]: 7168 -> rmsnorm_*_op, 6144 -> rmsnorm_*_glm5_op. +""" + +from __future__ import annotations + +import torch + +__all__ = [ + "BLOCK_SIZE", + "DIM_DEEPSEEK_V3_2", + "DIM_GLM_5", + "rmsnorm_quant", +] + +BLOCK_SIZE = 128 +DIM_DEEPSEEK_V3_2 = 7168 +DIM_GLM_5 = 6144 + + +def rmsnorm_quant( + hidden_in: torch.Tensor, + gamma_in: torch.Tensor, + hidden_out: torch.Tensor, + quant_hidden_out: torch.Tensor | None = None, + quant_hidden_scale_out: torch.Tensor | None = None, + profile_logs: torch.Tensor | None = None, +) -> None: + """ + Rmsnorm with optional activation quantization. + + Unified for deepseek_v3_2 (dim=7168) and glm_5 (dim=6144). Dispatches by + hidden_in.shape[-1]: 7168 -> rmsnorm_op / rmsnorm_quant_op, + 6144 -> rmsnorm_glm5_op / rmsnorm_quant_glm5_op. + + Args: + hidden_in: Input tensor (..., dim). + gamma_in: RMSNorm gamma (dim,). + hidden_out: RMSNorm output (..., dim). + quant_hidden_out: Optional quantized output (..., dim). If None, no quant. + quant_hidden_scale_out: Optional quant scale (..., dim // block_size). If None, no quant. + profile_logs: Optional profile logs tensor. + """ + dim = hidden_in.shape[-1] + if dim == DIM_GLM_5: + glm5_flag = "_glm5" + elif dim == DIM_DEEPSEEK_V3_2: + glm5_flag = "" + else: + raise ValueError( + f"Unsupported hidden_in.shape[-1]: {dim}. " + f"rmsnorm_quant supports {DIM_DEEPSEEK_V3_2} (deepseek_v3_2) or {DIM_GLM_5} (glm_5)." + ) + if quant_hidden_out is None or quant_hidden_scale_out is None: + quant_flag = "" + quant_args = [hidden_in, gamma_in, hidden_out, profile_logs] + else: + quant_flag = "_quant" + quant_args = [ + hidden_in, + gamma_in, + hidden_out, + quant_hidden_out, + quant_hidden_scale_out, + profile_logs, + ] + if profile_logs is None: + raise ValueError("profile_logs is required when calling rmsnorm_quant.") + func_name = f"rmsnorm{quant_flag}{glm5_flag}_op" + func_call = getattr(torch.ops.tilert, func_name) + func_call(*quant_args) diff --git a/python/models/deepseek_v3_2/ops/rmsnorm_up_gate_silu.py b/python/models/deepseek_v3_2/ops/rmsnorm_up_gate_silu.py new file mode 100644 index 0000000..e2f5c59 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/rmsnorm_up_gate_silu.py @@ -0,0 +1,371 @@ +"""RMSNormUpGateSiLU operation module.""" + +from dataclasses import dataclass +from enum import Enum + +import torch +import torch.nn.functional as F + +from tilert.models.base import TileRTModule +from tilert.models.common import weight_dequant +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.deepseek_v3_2.ops.expert_sel_up_gate_silu import ( + ExpertSelectUpGateSiLU, + ExpertSelectUpGateSiLUWeightsConverter, +) +from tilert.profiler.utils import parse_profile_log_tensor +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "RMSNormUpGateSiLUAlgorithm", + "RMSNormUpGateSiLU", + "RMSNormUpGateSiLUTilertWeightsAlias", + "rmsnorm_up_gate_silu", +] + + +def rmsnorm_up_gate_silu( + hidden_in: torch.Tensor, + gamma_in: torch.Tensor, + weights_in: torch.Tensor, + hidden_out: torch.Tensor, + profile_logs: torch.Tensor, + compute_kernel_type: str = "fp8mma", +) -> None: + """rmsnorm_up_gate_silu operation.""" + torch.ops.tilert.rmsnorm_up_gate_silu_op( + hidden_in, + gamma_in, + weights_in, + hidden_out, + profile_logs, + compute_kernel_type, + ) + + +class RMSNormUpGateSiLUAlgorithm(Enum): + """RMSNormUpGateSiLU algorithm""" + + FP8MMA = "fp8mma" + FP16MMA = "fp16mma" + + +RMSNormUpGateSiLUWeightsConverter = ExpertSelectUpGateSiLUWeightsConverter +ExpertSelectUpGateSiLUW = ExpertSelectUpGateSiLUWeightsConverter + + +@dataclass +class RMSNormUpGateSiLUTilertWeightsAlias: + """TileRT weights alias for RMSNormUpGateSiLU.""" + + unproj_o_gamma = "unproj_o_gamma" + gate_weights = "gate_weights" + gate_scales = "gate_scales" + up_weights = "up_weights" + up_scales = "up_scales" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [ + self.unproj_o_gamma, + self.gate_weights, + self.gate_scales, + self.up_weights, + self.up_scales, + ] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class RMSNormUpGateSiLU(TileRTModule): + """RMSNormUpGateSiLU module""" + + def __init__( + self, + model_args: ModelArgs, + device_id: int, + num_devices: int, + algorithm: RMSNormUpGateSiLUAlgorithm = RMSNormUpGateSiLUAlgorithm.FP8MMA, + ): + super().__init__( + self.__class__.__name__, + model_args=model_args, + device_id=device_id, + num_devices=num_devices, + ) + + self.arch_name = self.model_args.arch_name + self.dim = self.model_args.dim + + self.inter_dim = self.model_args.inter_dim + self.moe_inter_dim = self.model_args.moe_inter_dim + self.moe_inter_dim_per_device = self.moe_inter_dim // self.num_devices + self.inter_dim_per_device = self.inter_dim // self.num_devices + # effective number of experts + self.n_experts = self.inter_dim_per_device // self.moe_inter_dim_per_device + self.eps = self.model_args.eps + + self.block_size = self.model_args.block_size + self.algorithm = algorithm + + # reference weights + self.ref_norm_gamma: torch.Tensor | None = None + self.ref_gate: torch.Tensor | None = None + self.ref_up: torch.Tensor | None = None + + # tilert weights + self.tilert_norm_gamma: torch.Tensor | None = None + self.tilert_weights: torch.Tensor | None = None + # for compatibility, to be removed in the future + self.tilert_scales = torch.zeros( + 9, 4, 64, dtype=torch.bfloat16, device=torch.device("cuda") + ) + + # tilert vars + self.hidden_out: torch.Tensor | None = None + + self.profile_logs: torch.Tensor | None = None + self.is_init = False + + # tilert_funcs + self.rmsnorm_up_gate_silu_func = rmsnorm_up_gate_silu + + self.tilert_weights_alias = RMSNormUpGateSiLUTilertWeightsAlias() + + # reference tensor aliases + self.ref_tensor_alias: list[str] = [ + "post_attention_layernorm.weight", + "mlp.gate_proj.weight", + "mlp.gate_proj.weight_scale_inv", + "mlp.up_proj.weight", + "mlp.up_proj.weight_scale_inv", + ] + + @property + def tilert_tensor_alias(self) -> list[str]: + return self.tilert_weights_alias() + + def get_weights_list(self) -> list[torch.Tensor]: + """ + Get the weights list. + + Returns: + List of weights. + """ + return [self.tilert_norm_gamma, self.tilert_weights, self.tilert_scales] + + def device_sharding( + self, + weights_dict: dict[str, torch.Tensor], + key_prefix: str, # e.g. model.layers.{layer_id}.mlp + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Device sharding. + + Args: + weights_dict: Dictionary of weights. + + Returns: + Tuple of weights. + """ + rmsnorm_gamma_key = f"{key_prefix}.post_attention_layernorm.weight" + if ".mlp" in key_prefix: + key_prefix_without_mlp = key_prefix.replace(".mlp", "") + rmsnorm_gamma_key = f"{key_prefix_without_mlp}.post_attention_layernorm.weight" + elif key_prefix == "mlp": + rmsnorm_gamma_key = "post_attention_layernorm.weight" + rmsnorm_gamma = weights_dict[rmsnorm_gamma_key] + # repeat rmsnorm_gamma for each device + rmsnorm_gamma = rmsnorm_gamma[None, :].repeat(self.num_devices, 1) + + gate_weights, gate_scales, up_weights, up_scales = ( + ExpertSelectUpGateSiLU.process_gate_up_weights( + key_prefix, + weights_dict, + self.num_devices, + ) + ) + # Transpose split so to match the old convertcode + gate_weights = gate_weights.reshape(self.n_experts, self.num_devices, -1, self.dim) + gate_weights = gate_weights.transpose(0, 1) + gate_scales = gate_scales.reshape( + self.n_experts, self.num_devices, -1, self.dim // self.block_size + ) + gate_scales = gate_scales.transpose(0, 1) + up_weights = up_weights.reshape(self.n_experts, self.num_devices, -1, self.dim) + up_weights = up_weights.transpose(0, 1) + up_scales = up_scales.reshape( + self.n_experts, self.num_devices, -1, self.dim // self.block_size + ) + up_scales = up_scales.transpose(0, 1) + return ( + rmsnorm_gamma.contiguous(), + gate_weights.contiguous(), + gate_scales.contiguous(), + up_weights.contiguous(), + up_scales.contiguous(), + ) + + def init_reference_weights( + self, + state_dict: dict[str, torch.Tensor], + key_prefix: str, # e.g. model.layers.{layer_id}.mlp + device_id: int = 0, + ) -> None: + """ + Initialize the reference weights. + + Args: + state_dict: State dictionary. + device_id: Device ID. + """ + sharded_list = self.device_sharding(state_dict, key_prefix) + + gamma = sharded_list[0][device_id] + gate_weights = sharded_list[1][device_id] + gate_scales = sharded_list[2][device_id] + up_weights = sharded_list[3][device_id] + up_scales = sharded_list[4][device_id] + self.ref_norm_gamma = gamma + ref_gate_list = [ + weight_dequant(gate_weights, gate_scales) + for gate_weights, gate_scales in zip(gate_weights, gate_scales) + ] + ref_up_list = [ + weight_dequant(up_weights, up_scales) + for up_weights, up_scales in zip(up_weights, up_scales) + ] + self.ref_gate = torch.stack(ref_gate_list, dim=0) + self.ref_up = torch.stack(ref_up_list, dim=0) + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """ + Initialize the tilert weights. + + Args: + state_dict: State dictionary. + """ + assert self.algorithm is not None, "Algorithm is not set" + self.tilert_norm_gamma, self.tilert_weights = RMSNormUpGateSiLUWeightsConverter( + self.model_args, self.num_devices + ).dispatch(self.algorithm, [state_dict[alias] for alias in self.tilert_weights_alias()]) + + def init_tilert_vars(self, batch_size: int, seq_len: int, dev_id: int = 0) -> None: + """ + Initialize the tilert variables. + + Args: + batch_size: Batch size. + seq_len: Sequence length. + """ + # tilert vars + self.hidden_out = torch.zeros( + ( + batch_size, + seq_len, + self.n_experts, + self.moe_inter_dim_per_device, + ), + dtype=torch.bfloat16, + device=f"cuda:{dev_id}", + ) + + self.profile_logs = get_profile_log_tensor(device=f"cuda:{dev_id}") + self.is_init = True + + def init_random_weights(self, dev_id: int = 0) -> None: + """ + Initialize the random weights. + + Returns: + None + """ + gamma = torch.randn(self.dim, dtype=torch.float32, device=f"cuda:{dev_id}") + gate_weights = torch.randn( + self.inter_dim, self.dim, dtype=torch.bfloat16, device=f"cuda:{dev_id}" + ).to(torch.float8_e4m3fn) + up_weights = torch.randn( + self.inter_dim, self.dim, dtype=torch.bfloat16, device=f"cuda:{dev_id}" + ).to(torch.float8_e4m3fn) + inter_dim_scale_dim = self.inter_dim // self.block_size + dim_scale_dim = self.dim // self.block_size + scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16 + gate_scales = torch.randn( + inter_dim_scale_dim, dim_scale_dim, dtype=scale_dtype, device=f"cuda:{dev_id}" + ) + up_scales = torch.randn( + inter_dim_scale_dim, dim_scale_dim, dtype=scale_dtype, device=f"cuda:{dev_id}" + ) + tensor_list = [ + gamma, + gate_weights, + gate_scales, + up_weights, + up_scales, + ] + state_dict = dict(zip(self.ref_tensor_alias, tensor_list)) + self.init_reference_weights(state_dict, "mlp", dev_id) + sharded_list = self.device_sharding(state_dict, "mlp") + sharded_state_dict = { + alias: sharded_list[i][dev_id] for i, alias in enumerate(self.tilert_weights_alias()) + } + self.init_tilert_weights(sharded_state_dict) + + def golden_forward( + self, + x_in: torch.Tensor, + ) -> torch.Tensor: + assert self.ref_gate is not None + assert self.ref_up is not None + bsz = x_in.shape[0] + seq_len = x_in.shape[1] + assert bsz == 1 + x_in_rmsnorm = torch.nn.functional.rms_norm( + x_in.float(), [x_in.size(-1)], self.ref_norm_gamma, self.eps + ) + hidden_out_list = [] + for s in range(seq_len): + # ref up-gate silu + hidden_out_w1_list = [] + hidden_out_w3_list = [] + + for i in range(self.n_experts): + hidden_out_w1_sel = x_in_rmsnorm[0, s].float() @ self.ref_gate[i].float().T + hidden_out_w3_sel = x_in_rmsnorm[0, s].float() @ self.ref_up[i].float().T + hidden_out_w1_list.append(hidden_out_w1_sel) + hidden_out_w3_list.append(hidden_out_w3_sel) + hidden_out_w1 = torch.stack(hidden_out_w1_list, dim=0) + hidden_out_w3 = torch.stack(hidden_out_w3_list, dim=0) + hidden_out = F.silu(hidden_out_w1.float()) * hidden_out_w3.float() + hidden_out = hidden_out.to(torch.bfloat16) + hidden_out_list.append(hidden_out) + hidden_out = torch.stack(hidden_out_list, dim=0) + hidden_out = hidden_out[None, ...] + return hidden_out + + def tilert_forward( + self, + x_in: torch.Tensor, + ) -> torch.Tensor: + assert self.rmsnorm_up_gate_silu_func is not None + assert self.algorithm is not None, "Algorithm is not set" + self.rmsnorm_up_gate_silu_func( + x_in, + self.tilert_norm_gamma, + self.tilert_weights, + self.hidden_out, + self.profile_logs, + self.algorithm.value, + ) + if self.flag_enable_profiling_log: + parse_profile_log_tensor( + self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)] + ) + return self.hidden_out + + def __call__( + self, + x_in: torch.Tensor, + ) -> torch.Tensor: + return self.golden_forward(x_in) diff --git a/python/models/deepseek_v3_2/ops/rotate.py b/python/models/deepseek_v3_2/ops/rotate.py new file mode 100644 index 0000000..539f334 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/rotate.py @@ -0,0 +1,210 @@ +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + +from tilert.models.base import TileRTModule +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.utils import apply_rotary_emb +from tilert.profiler.utils import parse_profile_log_tensor +from tilert.utils import get_profile_log_tensor + +try: + from fast_hadamard_transform import hadamard_transform + + def rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + hidden_size = x.size(-1) + return hadamard_transform(x, scale=hidden_size**-0.5) + +except ImportError: + print( + "Cannot import hadamard_transform, fallback to scipy.linalg.hadamard." + "please install fast_hadamard_transform for correct performance." + ) + import math + + from scipy.linalg import hadamard + + def hadamard_transform_ref(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + x_shape = x.shape + dim = x.shape[-1] + x = x.reshape(-1, dim) + log_dim = math.ceil(math.log2(dim)) + dim_padded = 2**log_dim + if dim != dim_padded: + x = F.pad(x, (0, dim_padded - dim)) + out = F.linear( + x, + torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device), + ) + out = out * scale + return out[..., :dim].reshape(*x_shape) + + def rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + hidden_size = x.size(-1) + return hadamard_transform_ref(x, scale=hidden_size**-0.5) + + +__all__ = [ + "rotate", + "rotate_activation", + "Rotate", + "RotateRefWeightsAlias", + "RotateTilertWeightsAlias", +] + + +def rotate( + input_raw: torch.Tensor, + output_raw: torch.Tensor, + freqs_cis_raw: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Rotate (hadamard transform) operation. + + Unified for deepseek_v3_2 (64 heads) and glm_5 (32 heads). Dispatches by + input_raw.shape[-2]: 64 -> rotate_op, 32 -> rotate_glm5_op. + + Args: + input_raw (torch.Tensor): The input tensor [..., head, 128]. + output_raw (torch.Tensor): The output tensor where the result will be stored. + freqs_cis_raw (torch.Tensor): The frequency tensor. + profile_logs (torch.Tensor): Tensor for storing profiling logs. + + Returns: + None + """ + if input_raw.dtype != torch.bfloat16: + raise ValueError("input must be a bfloat16 tensor.") + + if output_raw.dtype != torch.bfloat16: + raise ValueError("output must be a bfloat16 tensor.") + + if freqs_cis_raw.dtype != torch.float32: + raise ValueError("freqs_cis must be a float32 tensor.") + + head = input_raw.shape[-2] + dim = input_raw.shape[-1] + if dim != 128: + raise ValueError("dim must be 128, as we precompute scale inner kernel") + + if head == 64: + torch.ops.tilert.rotate_op(input_raw, output_raw, freqs_cis_raw, profile_logs) + elif head == 32: + torch.ops.tilert.rotate_glm5_op(input_raw, output_raw, freqs_cis_raw, profile_logs) + else: + raise ValueError( + f"Unsupported head size: {head}. Rotate op supports " + "index_n_heads=64 (deepseek_v3_2) or 32 (glm_5)." + ) + + +@dataclass +class RotateRefWeightsAlias: + """Reference weights alias for Rotate (no weights).""" + + @property + def ref_tensor_alias(self) -> list[str]: + return [] + + def __call__(self) -> list[str]: + return self.ref_tensor_alias + + +@dataclass +class RotateTilertWeightsAlias: + """TileRT weights alias for Rotate (no weights).""" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class Rotate(TileRTModule): + """Rotate module: RoPE on first qk_rope_head_dim dims + hadamard transform. + + Unified for deepseek_v3_2 (index_n_heads=64) and glm_5 (index_n_heads=32). + No weights; uses model_args for dimensions. + """ + + def __init__( + self, + model_args: ModelArgs, + num_devices: int = 1, + device_id: int = 0, + ref_weights_alias: RotateRefWeightsAlias | None = None, + ): + super().__init__( + self.__class__.__name__, + model_args=model_args, + num_devices=num_devices, + device_id=device_id, + ) + self.tilert_weights_alias = RotateTilertWeightsAlias() + self.ref_weights_alias = ( + ref_weights_alias if ref_weights_alias is not None else RotateRefWeightsAlias() + ) + + self.qk_rope_head_dim = model_args.qk_rope_head_dim + self.index_n_heads = model_args.index_n_heads + self.index_head_dim = model_args.index_head_dim + + self.output: torch.Tensor | None = None + self.profile_logs: torch.Tensor | None = None + + def get_weights_list(self) -> list[torch.Tensor]: + return [] + + def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + del weights_map + return {} + + def init_reference_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + del state_dict + pass + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + del state_dict + pass + + def init_random_weights(self) -> None: + pass + + def init_tilert_vars(self, batch_size: int, seq_len: int) -> None: + self.output = torch.zeros( + (batch_size, seq_len, self.index_n_heads, self.index_head_dim), + dtype=torch.bfloat16, + ) + self.profile_logs = get_profile_log_tensor() + self.is_init = True + + def golden_forward( + self, + idx_q: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + q_pe_idx, q_nope_idx = torch.split( + idx_q, + [self.qk_rope_head_dim, self.index_head_dim - self.qk_rope_head_dim], + dim=-1, + ) + q_pe_idx = apply_rotary_emb(q_pe_idx, freqs_cis) + idx_q = torch.cat([q_pe_idx, q_nope_idx], dim=-1) + return rotate_activation(idx_q) + + def tilert_forward(self, idx_q: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + assert self.output is not None + assert self.profile_logs is not None + freqs_cis_real = torch.view_as_real(freqs_cis).reshape(*freqs_cis.shape[:-1], -1) + rotate(idx_q, self.output, freqs_cis_real, self.profile_logs) + if self.flag_enable_profiling_log: + parse_profile_log_tensor( + self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)] + ) + return self.output diff --git a/python/models/deepseek_v3_2/ops/sparse_index.py b/python/models/deepseek_v3_2/ops/sparse_index.py new file mode 100644 index 0000000..0c21ce8 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/sparse_index.py @@ -0,0 +1,124 @@ +"""Sparse index operation module.""" + +import torch + +__all__ = [ + "sparse_index", + "sparse_index_topk", +] + + +def sparse_index( + q: torch.Tensor, # noqa: VNE001 + kv: torch.Tensor, + weights: torch.Tensor, + logits: torch.Tensor, + cur_pos: int, + profile_logs: torch.Tensor, +) -> None: + """ + Sparse index operation. + + Calculate sparse index using q * kv * weights. + + Args: + q (torch.Tensor): The query tensor. + kv (torch.Tensor): The key-value tensor. + weights (torch.Tensor): The weights tensor. + logits (torch.Tensor): The logits tensor. + cur_pos (int): The position of the first token. + profile_logs (torch.Tensor): Tensor for storing profiling logs. + + Returns: + None + """ + if q.dtype != torch.bfloat16: + raise ValueError("input must be a bfloat16 tensor.") + if kv.dtype != torch.bfloat16: + raise ValueError("kv must be a bfloat16 tensor.") + if weights.dtype != torch.bfloat16: + raise ValueError("weights must be a bfloat16 tensor.") + if logits.dtype != torch.float32: + raise ValueError("logits must be a float32 tensor.") + + head = q.shape[-2] + dim = q.shape[-1] + + if head != 64 and head != 32: + raise ValueError( + f"Unsupported head size: {head}. Sparse index op currently only \ + supports a head number of 64 or 32." + ) + if dim != 128: + raise ValueError("dim must be 128, as we precompute scale inner kernel") + + device = q.device + if any(t.device != device for t in (kv, weights, logits, profile_logs)): + raise ValueError( + "sparse_index inputs must be on the same device: " + f"q={device}, kv={kv.device}, weights={weights.device}, " + f"logits={logits.device}, profile_logs={profile_logs.device}" + ) + if head == 64: + torch.ops.tilert.sparse_index_op(q, kv, weights, logits, cur_pos, profile_logs) + elif head == 32: + torch.ops.tilert.sparse_index_glm5_op(q, kv, weights, logits, cur_pos, profile_logs) + + +def sparse_index_topk( + q: torch.Tensor, # noqa: VNE001 + kv: torch.Tensor, + weights: torch.Tensor, + logits: torch.Tensor, + indices: torch.Tensor, + cur_pos: int, + profile_logs: torch.Tensor, +) -> None: + """ + Sparse index operation. + + Calculate sparse index using q * kv * weights. + + Args: + q (torch.Tensor): The query tensor. + kv (torch.Tensor): The key-value tensor. + weights (torch.Tensor): The weights tensor. + logits (torch.Tensor): The logits tensor. + cur_pos (int): The position of the first token. + profile_logs (torch.Tensor): Tensor for storing profiling logs. + + Returns: + None + """ + if q.dtype != torch.bfloat16: + raise ValueError("input must be a bfloat16 tensor.") + if kv.dtype != torch.bfloat16: + raise ValueError("kv must be a bfloat16 tensor.") + if weights.dtype != torch.bfloat16: + raise ValueError("weights must be a bfloat16 tensor.") + if logits.dtype != torch.float32: + raise ValueError("logits must be a float32 tensor.") + + seqlen = q.shape[-3] + head = q.shape[-2] + dim = q.shape[-1] + + if head != 32: + raise ValueError( + f"Unsupported head size: {head}. Sparse index topk fused op currently only \ + supports a head number of 32." + ) + if dim != 128: + raise ValueError("dim must be 128, as we precompute scale inner kernel") + + device = q.device + if any(t.device != device for t in (kv, weights, logits, indices, profile_logs)): + raise ValueError( + "sparse_index inputs must be on the same device: " + f"q={device}, kv={kv.device}, weights={weights.device}, " + f"logits={logits.device}, profile_logs={profile_logs.device}" + ) + workspace = torch.zeros(seqlen, (200 * 1024 + 258), dtype=torch.int32, device=device) + torch.ops.tilert.sparse_index_topk_glm5_op( + q, kv, weights, logits, cur_pos, indices, workspace, profile_logs + ) diff --git a/python/models/deepseek_v3_2/ops/top1_allreduce.py b/python/models/deepseek_v3_2/ops/top1_allreduce.py new file mode 100644 index 0000000..1d500e3 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/top1_allreduce.py @@ -0,0 +1,25 @@ +"""Top1 Allreduce operation""" + +import torch + +__all__ = [ + "top1_allreduce", +] + + +def top1_allreduce( + logits: torch.Tensor, + flag: int, + index_out: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """ + Define the Top1 Allreduce operation. + + Args: + logits: Input tensor. + flag: Flag. + index_out: Output tensor. + profile_logs: Profile logs tensor. + """ + torch.ops.tilert.top1_allreduce_op(logits, flag, index_out, profile_logs) diff --git a/python/models/deepseek_v3_2/ops/top_p.py b/python/models/deepseek_v3_2/ops/top_p.py new file mode 100644 index 0000000..4394c2a --- /dev/null +++ b/python/models/deepseek_v3_2/ops/top_p.py @@ -0,0 +1,68 @@ +"""TopP operation module.""" + +import torch + +__all__ = [ + "top_p", +] + + +def top_p( + logits: torch.Tensor, + in_indices: torch.Tensor, + sampling_seed: torch.Tensor, + positions: torch.Tensor, + is_verify_mode: bool, + temperature: float, + top_p: float, + top_k: int, + flag: int, + indices: torch.Tensor, + scores: torch.Tensor, + debug_tensor: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """top_p operation. + + Args: + logits (Tensor): The logits tensor. + in_indices (Tensor): The tensor containing input indices. + sampling_seed (Tensor): Random seeds for each sequence position. + positions (Tensor): Token positions for each sequence element. + is_verify_mode (bool): A flag indicating if verify mode is enabled in MTP. When set to + `True`, the `in_indices` will be checked to check if it is in + the top-k values. + temperature (float): The temperature parameter, used for scaling logits in softmax + calculations. + top_p (float): The top-p value, used for nucleus sampling to restrict the selection to the + smallest set of tokens whose cumulative probability is greater than or equal + to `top_p`. + top_k (int): The number of top-k values that occupy the top-p probability mass + during sampling. + flag (int): Used in all reduction. + indices (Tensor): The tensor containing output indices. + scores (Tensor): The tensor containing corresponding scores for the indices. + profile_logs (Tensor): A tensor for storing profiling log data during execution in MTP. + """ + dim = logits.shape[-1] + if dim == 19360: + call_func = torch.ops.tilert.top_p_glm5_op + elif dim == 16160: + call_func = torch.ops.tilert.top_p_op + else: + raise ValueError(f"Unsupported dimension: {dim}") + call_func( + logits, + in_indices, + sampling_seed, + positions, + is_verify_mode, + temperature, + top_p, + top_k, + flag, + indices, + scores, + debug_tensor, + profile_logs, + ) diff --git a/python/models/deepseek_v3_2/ops/topk.py b/python/models/deepseek_v3_2/ops/topk.py new file mode 100644 index 0000000..bb41575 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/topk.py @@ -0,0 +1,139 @@ +"""topk operations module.""" + +import torch +import torch.nn as nn + +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "TopK", + "topk_approximate", + "topk_accurate", +] + + +def topk_approximate( + logits: torch.Tensor, + seq_len: int, + topk: int, + profile_logs: torch.Tensor, +) -> torch.Tensor: + """ + Topk approximate operation. + + Topk approximate the input tensor `logits` and stores the result in `output_raw`. + + Args: + logits (torch.Tensor): The input tensor. + seq_len (int): valid data of logits.shape[-1] + topk (int): The number of topk to approximate. + profile_logs (torch.Tensor): The profile logs tensor. + + Returns: + indices (torch.Tensor): The output tensor. + """ + if logits.dtype != torch.float32: + raise ValueError("logits must be a float32 tensor.") + + if topk != 2048: + raise ValueError("topk must be 2048.") + batch = logits.shape[0] + if batch != 1: + raise ValueError("batch must be 1 in this version") + + indices = torch.zeros(batch, topk, dtype=torch.int32, device=logits.device) + torch.ops.tilert.topk_approximate_op(logits, indices, seq_len, profile_logs) + + return indices + + +def topk_accurate( + logits: torch.Tensor, + seq_len: int, + topk: int, + profile_logs: torch.Tensor, +) -> torch.Tensor: + """ + Topk approximate operation. + + Topk approximate the input tensor `logits` and stores the result in `output_raw`. + + Args: + logits (torch.Tensor): The input tensor. + seq_len (int): length of last samples, + for k=logits.shape[1] samples, the length is + seq-k+1, seq-k+2, ..., seq-1, seq + topk (int): The number of topk to approximate. + profile_logs (torch.Tensor): The profile logs tensor. + Returns: + indices (torch.Tensor): The output tensor. + """ + if logits.dtype != torch.float32: + raise ValueError("logits must be a float32 tensor.") + + if topk != 2048: + raise ValueError("topk must be 2048.") + + assert logits.shape[0] == 1, "batch must be 1 in this version" + num_samples = logits.shape[1] + + indices = torch.zeros(num_samples, topk, dtype=torch.int32, device=logits.device) + indices_ws = torch.zeros(1, num_samples, 4, topk * 2, dtype=torch.int32, device=logits.device) + torch.ops.tilert.topk_accurate_op( + logits, indices, seq_len - num_samples, indices_ws, profile_logs + ) + + return indices + + +class TopK(nn.Module): + """TopK operation with optional approximate kernel. + + Wraps topk_accurate / topk_approximate and provides golden_forward + (reference implementation) and tilert_forward (TileRT kernel). + """ + + def __init__(self, use_approximate: bool = False) -> None: + super().__init__() + self.use_approximate = use_approximate + + def golden_forward( + self, + logits: torch.Tensor, + topk: int, + ) -> torch.Tensor: + """Reference forward: torch.topk on the last dimension. + + Args: + logits: Scores tensor, shape (batch, ..., seq_len). + topk: Number of top indices to return. + + Returns: + Indices of top-k values along the last dimension. + """ + seq_len = logits.shape[-1] + return logits.topk(min(topk, seq_len), dim=-1)[1] + + def tilert_forward( + self, + logits: torch.Tensor, + topk: int, + ) -> torch.Tensor: + """Tilert forward: batch of samples with varying valid length. + + Args: + logits: Shape (batch, num_samples, cache_len). + topk: Number of top indices to return. + + Returns: + Indices tensor of shape (batch, num_samples, topk). + """ + profile_logs = get_profile_log_tensor(device=logits.device) + cache_len = logits.shape[-1] + if self.use_approximate: + indices = topk_approximate(logits, cache_len, topk, profile_logs) + else: + indices = topk_accurate(logits, cache_len, topk, profile_logs) + if indices.dim() == 2: + return indices.unsqueeze(0) + return indices diff --git a/python/models/deepseek_v3_2/ops/unproj_o_allreduce.py b/python/models/deepseek_v3_2/ops/unproj_o_allreduce.py new file mode 100644 index 0000000..50b413f --- /dev/null +++ b/python/models/deepseek_v3_2/ops/unproj_o_allreduce.py @@ -0,0 +1,428 @@ +"""UnprojOAllreduce operation module.""" + +from dataclasses import dataclass +from enum import Enum + +import torch + +from tilert.models.base import TileRTModule, TilertWeightsConverter +from tilert.models.common import weight_dequant +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.profiler.utils import parse_profile_log_tensor +from tilert.utils import get_profile_log_tensor + +__all__ = [ + "unproj_o_allreduce", + "UnProjOAllReduce", + "UnProjOAllReduceRefWeightsAlias", + "UnProjOAllReduceTilertWeightsAlias", +] + + +def unproj_o_allreduce( + vec_in: torch.Tensor, + mat_in: torch.Tensor, + mat_scale: torch.Tensor, + x_in: torch.Tensor, + flag: int, + vec_out: torch.Tensor, + profile_logs: torch.Tensor, + algorithm: str = "fp8mma", +) -> None: + """ + Fused operation of unprojection and allreduce. + + Args: + vec_in: Input tensor. + mat_in: Input tensor. + mat_scale: Input tensor. + x_in: Input tensor. + flag: Input flag. + vec_out: Output tensor. + profile_logs: Profile logs tensor. This is a 1D tensor of shape + (num_sms,) to store the profile logs of the unproj_o_allreduce + operation, where num_sms is the number of SMs on the + device. + """ + if vec_out.shape[-1] == 7168: + assert algorithm == "fp8mma", "Only fp8mma is supported for deepseek_v3_2" + torch.ops.tilert.unproj_o_allreduce_op( + vec_in, mat_in, mat_scale, x_in, flag, vec_out, profile_logs + ) + + elif vec_out.shape[-1] == 6144: + torch.ops.tilert.unproj_o_allreduce_glm5_op( + vec_in, mat_in, mat_scale, x_in, flag, vec_out, profile_logs, algorithm + ) + else: + raise ValueError(f"Unsupported vector dimension: {vec_out.shape[-1]}") + + +class UnProjOAllReduceAlgorithm(Enum): + """UnprojOAllReduce algorithm""" + + FP8MMA = "fp8mma" + FP16MMA = "fp16mma" + + +@dataclass +class UnProjOAllReduceRefWeightsAlias: + """Reference weights alias for UnProjOAllReduce.""" + + o_proj_weight = "self_attn.o_proj.weight" + o_proj_scale_inv = "self_attn.o_proj.weight_scale_inv" + + @property + def ref_tensor_alias(self) -> list[str]: + return [self.o_proj_weight, self.o_proj_scale_inv] + + def __call__(self) -> list[str]: + return self.ref_tensor_alias + + +@dataclass +class UnProjOAllReduceTilertWeightsAlias: + """TileRT weights alias for UnProjOAllReduce.""" + + unproj_weights = "unproj_weights" + unproj_scales = "unproj_scales" + + @property + def tilert_tensor_alias(self) -> list[str]: + return [self.unproj_weights, self.unproj_scales] + + def __call__(self) -> list[str]: + return self.tilert_tensor_alias + + +class UnProjOAllReduceWeightsConverter(TilertWeightsConverter): + """UnProjOAllReduce weights converter""" + + @staticmethod + def _swizzle_qmma_16x32(mat_in: torch.Tensor) -> torch.Tensor: + assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32 + assert mat_in.dtype == torch.float8_e4m3fn + # PTX isa fig.88 + pre_shape = mat_in.shape[:-2] + mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4) + return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2) + + def convert_to_fp8mma( + self, weights_list: list[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Convert the weights to fp8mma format. + + Args: + weights_list: List of weights. + + Returns: + Tuple of weights. + """ + args = self.model_args + assert args.arch_name == "deepseek_v3_2" or args.arch_name == "glm_5" + arch_name = args.arch_name + dim = args.dim + num_sms = 128 + if arch_name == "deepseek_v3_2": + num_sms = 112 + dim_per_sm = dim // num_sms + dim_scale_dim = dim // args.block_size + + with torch.inference_mode(): + mat_in, scales_trt = weights_list + vec_dim = mat_in.shape[-1] # 2048 for both deepseek_v3_2 and glm_5 + assert scales_trt.shape == (dim // args.block_size, vec_dim // args.block_size) + + weights_trt = mat_in.reshape(num_sms, dim_per_sm, vec_dim) + # dim_per_stage is 512 + stages = vec_dim // 512 + weights_trt = weights_trt.reshape(num_sms, dim_per_sm, stages, 512).transpose(1, 2) + + weights_trt = weights_trt.reshape( + num_sms, stages, dim_per_sm // 16, 16, 16, 32 + ).transpose(-2, -3) + weights_trt = self._swizzle_qmma_16x32(weights_trt) + weights_trt = weights_trt.reshape(num_sms, stages, -1) + + if arch_name == "glm_5": + if scales_trt.dtype != torch.float32: + print( + "Warning: UnProjOAllReduceWeightsConverter: " + + f"scales_trt.dtype: {scales_trt.dtype} " + + "is not float32, convert to float32." + ) + scales_trt = scales_trt.to(torch.float32) + # repeat 8 times + scales_trt = ( + scales_trt.reshape((dim_scale_dim, 1, -1)).repeat(1, 8, 1).reshape(num_sms, -1) + ) + else: # DS v3.2, use bfloat16 for scales + scales_trt = scales_trt.to(torch.bfloat16) + + return weights_trt.contiguous(), scales_trt.contiguous() + + @staticmethod + def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor: + assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16 + # PTX isa fig.88 + pre_shape = mat_in.shape[:-2] + mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4) + return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2) + + def convert_to_fp16mma( + self, + weights_list: list[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """Convert common weights to TileRT FP16 MMA layout.""" + assert self.model_args.arch_name == "glm_5", "Only GLM-5 supports FP16 MMA" + + with torch.inference_mode(): + mat, scales = weights_list + if scales.dtype != torch.float32: + print( + "Warning: UnProjOAllReduceWeightsConverter: " + + f"scales.dtype: {scales.dtype} " + + "is not float32, convert to float32." + ) + scales = scales.to(torch.float32) + + sms = 128 # use 128 sms for glm_5 + pages = 4 + scales = scales.reshape(6144 // 128, 1, 2048 // 128) + scales = scales.repeat(1, 8, 1) + scales = scales.reshape(128, 3, 4, 4).transpose(1, 2) + # to 128, 4, 12x4 + scales = scales.reshape(128, 4, 12).view(torch.float8_e4m3fn) + + mat = ( + mat.reshape(128, 48, 2048) + .reshape(128, 3, 16, 4, 512) + .transpose(2, 3) + .reshape(128, 3, 4, 16, 32, 16) + .transpose(3, 4) + .reshape(128, 3, 4, 32, 16, 16) + ) + mat = UnProjOAllReduceWeightsConverter._swizzle_mma_16x16(mat) + mat = mat.transpose(1, 2).reshape(128, 4, -1) + + scales_padding = torch.zeros( + sms, + pages, + 128 - scales.shape[-1], + dtype=torch.float8_e4m3fn, + device=mat.device, + ) # append 128-byte aligned scale: (128, 4, 24704) for glm_5 + mat_full = torch.cat([mat, scales, scales_padding], dim=-1).contiguous() + dummy_scales = torch.zeros(1, dtype=torch.float32, device=mat.device) + return mat_full, dummy_scales + + +class UnProjOAllReduce(TileRTModule): + """UnProjOAllReduce module""" + + def __init__( + self, + model_args: ModelArgs, + num_devices: int, + device_id: int = 0, + ref_weights_alias: UnProjOAllReduceRefWeightsAlias | None = None, + tilert_weights_alias: UnProjOAllReduceTilertWeightsAlias | None = None, + algorithm: UnProjOAllReduceAlgorithm = UnProjOAllReduceAlgorithm.FP8MMA, + ): + super().__init__( + self.__class__.__name__, + model_args=model_args, + num_devices=num_devices, + device_id=device_id, + ) + + self.tilert_weights_alias = ( + tilert_weights_alias + if tilert_weights_alias is not None + else UnProjOAllReduceTilertWeightsAlias() + ) + self.ref_weights_alias = ( + ref_weights_alias + if ref_weights_alias is not None + else UnProjOAllReduceRefWeightsAlias() + ) + + self.arch_name = self.model_args.arch_name + self.dim = self.model_args.dim + self.n_heads = self.model_args.n_heads + self.head_dim = self.model_args.v_head_dim + + self.block_size = self.model_args.block_size + self.algorithm: UnProjOAllReduceAlgorithm = algorithm + + # reference weights + self.ref_unproj_o: torch.Tensor | None = None + + # tilert weights + self.tilert_weights: torch.Tensor | None = None + self.tilert_scales: torch.Tensor | None = None + + # tilert vars + self.hidden_out: torch.Tensor | None = None + + self.profile_logs: torch.Tensor | None = None + self.is_var_init = False + + def get_weights_list(self) -> list[torch.Tensor]: + """ + Get the weights list. + + Returns: + List of weights. + """ + return [self.tilert_weights, self.tilert_scales] + + def device_sharding(self, weights_map: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Device sharding. + + Args: + weights_map: Map from ref weight alias to tensor (full model). + + Returns: + Map from tilert weight alias to (num_devices, ...) tensors. + """ + unproj_o_weight = weights_map[self.ref_weights_alias.o_proj_weight] + unproj_o_scale = weights_map[self.ref_weights_alias.o_proj_scale_inv] + unproj_o_weight = unproj_o_weight.reshape(self.dim, self.num_devices, -1) + unproj_o_weight = unproj_o_weight.transpose(0, 1) + unproj_o_scale = unproj_o_scale.reshape(self.dim // self.block_size, self.num_devices, -1) + unproj_o_scale = unproj_o_scale.transpose(0, 1) + return { + self.tilert_weights_alias.unproj_weights: unproj_o_weight.contiguous(), + self.tilert_weights_alias.unproj_scales: unproj_o_scale.contiguous(), + } + + def init_reference_weights( + self, + state_dict: dict[str, torch.Tensor], + device_id: int | None = None, + ) -> None: + """ + Initialize the reference weights. + + Args: + state_dict: State dictionary keyed by ref weight alias (full model). + device_id: Device ID for this shard; defaults to self.device_id. + """ + did = self.device_id if device_id is None else device_id + sharded = self.device_sharding(state_dict) + weights = sharded[self.tilert_weights_alias.unproj_weights][did] + scales = sharded[self.tilert_weights_alias.unproj_scales][did] + self.ref_unproj_o = weight_dequant(weights, scales) + + def init_tilert_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """ + Initialize the tilert weights. + + Args: + state_dict: State dictionary keyed by tilert weight alias (per-device). + """ + assert self.algorithm is not None, "Algorithm is not set" + self.tilert_weights, self.tilert_scales = UnProjOAllReduceWeightsConverter( + self.model_args, self.num_devices + ).dispatch( + self.algorithm, + [state_dict[alias] for alias in self.tilert_weights_alias()], + ) + + def init_tilert_vars(self, batch_size: int, seq_len: int) -> None: + """ + Initialize the tilert variables. + + Args: + batch_size: Batch size. + seq_len: Sequence length. + """ + self.hidden_out = torch.zeros( + (batch_size, seq_len, self.dim), + dtype=torch.bfloat16, + device=f"cuda:{self.device_id}", + ) + self.profile_logs = get_profile_log_tensor(device=f"cuda:{self.device_id}") + self.is_var_init = True + + def init_random_weights(self) -> None: + """Initialize the random weights.""" + unproj_o_weights = torch.randn( + self.dim, + self.n_heads * self.head_dim, + dtype=torch.bfloat16, + device=f"cuda:{self.device_id}", + ).to(torch.float8_e4m3fn) + + head_scale_dim = self.head_dim // self.block_size + dim_scale_dim = self.dim // self.block_size + scale_dtype = torch.float32 if self.arch_name == "glm_5" else torch.bfloat16 + unproj_o_scales = torch.randn( + dim_scale_dim, + self.n_heads * head_scale_dim, + dtype=scale_dtype, + device=f"cuda:{self.device_id}", + ) + ref_state_dict = { + self.ref_weights_alias.o_proj_weight: unproj_o_weights, + self.ref_weights_alias.o_proj_scale_inv: unproj_o_scales, + } + + self.init_reference_weights(ref_state_dict) + sharded = self.device_sharding(ref_state_dict) + per_device_state = {k: v[self.device_id] for k, v in sharded.items()} + self.init_tilert_weights(per_device_state) + + def golden_forward( + self, + vec_in: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass for the down-project module. + + Args: + vec_in: Input vector. + + Returns: + Output tensor. + """ + assert self.ref_unproj_o is not None + bsz = vec_in.shape[0] + seq_len = vec_in.shape[1] + assert bsz == 1 + res = vec_in.reshape(bsz, seq_len, -1).float() @ self.ref_unproj_o.T.float() + return res.to(torch.bfloat16) + + def tilert_forward( + self, + vec_in: torch.Tensor, + x_in: torch.Tensor, + flag: int, + ) -> torch.Tensor: + assert self.hidden_out is not None + assert self.profile_logs is not None + assert self.algorithm is not None + unproj_o_allreduce( + vec_in, + self.tilert_weights, + self.tilert_scales, + x_in, + flag, + self.hidden_out, + self.profile_logs, + self.algorithm.value, + ) + if self.flag_enable_profiling_log: + parse_profile_log_tensor( + self.profile_logs, self.get_profile_log_path(), [(self.op_name, 0.0)] + ) + return self.hidden_out + + def __call__( + self, + vec_in: torch.Tensor, + ) -> torch.Tensor: + return self.golden_forward(vec_in) diff --git a/python/models/deepseek_v3_2/ops/up_gate_silu.py b/python/models/deepseek_v3_2/ops/up_gate_silu.py new file mode 100644 index 0000000..2f214c0 --- /dev/null +++ b/python/models/deepseek_v3_2/ops/up_gate_silu.py @@ -0,0 +1,24 @@ +"""UpGateSiLU operation module.""" + +import torch + +__all__ = [ + "up_gate_silu", +] + + +def up_gate_silu( + hidden_in: torch.Tensor, + expert_indices_in: torch.Tensor, + experts_weights_in: torch.Tensor, + hidden_out: torch.Tensor, + profile_logs: torch.Tensor, +) -> None: + """Up Gate SiLU operation.""" + torch.ops.tilert.up_gate_silu_op( + hidden_in, + expert_indices_in, + experts_weights_in, + hidden_out, + profile_logs, + ) diff --git a/python/models/deepseek_v3_2/params.py b/python/models/deepseek_v3_2/params.py deleted file mode 100644 index a3914f8..0000000 --- a/python/models/deepseek_v3_2/params.py +++ /dev/null @@ -1,942 +0,0 @@ -from abc import abstractmethod - -import torch - -from tilert.models.deepseek_v3_2.model_args import ModelArgs as ModelArgsV3_2 -from tilert.models.utils import SwizzleMode, gen_tensor_swizzle_map_1d, precompute_freqs_cis - -__all__ = [ - "IntermediateMapper", - "BaseParams", - "MlaParams", - "MLPParams", - "MoEParams", - "TempVars", - "DenseLayerParamsKeys", - "MoELayerParamsKeys", - "CacheVars", - "gen_down_allreduce_fp8_params", -] - - -DenseLayerParamsKeys = [ - # MLA params - "x_rmsnorm_gamma", # 0 - "qkv_wa_weights", # 1 - "qkv_wa_scales", # 2 - "k_weights", # 3 - "k_bias", # 4 - "q_rmsnorm_gamma", # 5 - "q_wb_weights", # 6 - "q_wb_scales", # 7 - "id_score_weights", # 8 - "wkv_b1_weights", # 9 - "wkv_b1_scales", # 10 - "kv_rmsnorm_gamma", # 11 - "wkv_b2_weights", # 12 - "wkv_b2_scales", # 13 - "unproj_weights", # 14 - "unproj_scales", # 15 - # MLP params - "unproj_o_gamma", # 16 - "upgate_weights", # 17 - "upgate_scales", # 18 - "down_weights", # 19 - "down_scales", # 20 -] - -MoELayerParamsKeys = [ - # MLA params - "x_rmsnorm_gamma", # 0 - "qkv_wa_weights", # 1 - "qkv_wa_scales", # 2 - "k_weights", # 3 - "k_bias", # 4 - "q_rmsnorm_gamma", # 5 - "q_wb_weights", # 6 - "q_wb_scales", # 7 - "id_score_weights", # 8 - "wkv_b1_weights", # 9 - "wkv_b1_scales", # 10 - "kv_rmsnorm_gamma", # 11 - "wkv_b2_weights", # 12 - "wkv_b2_scales", # 13 - "unproj_weights", # 14 - "unproj_scales", # 15 - # MoE params - "unproj_o_gamma", # 16 - "exp_proj_weights", # 17 - "exp_bias", # 18 - "exp_upgate_weights", # 19 - "exp_upgate_scales", # 20 - "exp_down_weights", # 21 - "exp_down_scales", # 22 -] - - -def gen_down_allreduce_fp8_params(mat_in: torch.Tensor, mat_scale_in: torch.Tensor) -> torch.Tensor: - """Convert tilert mat and scale to tilert-fp8 input format.""" - mat_and_scale_in = torch.zeros( - (9, 128, (56 * 256 + 64)), dtype=torch.float8_e4m3fn, device=mat_in.device - ) - scale_part = mat_and_scale_in[..., 56 * 256 :] - mat_part = mat_and_scale_in[..., : 56 * 256].reshape(9, 128, 2, 56 * 8, 16) - mat_in = mat_in.reshape(9, 128, 56, 2, 8, 16) - mat_in = mat_in.transpose(2, 3).reshape(9, 128, 2, 56 * 8, 16) - - swizzle_map = gen_tensor_swizzle_map_1d(56, 8, SwizzleMode.SWIZZLE_128B) - mat_part[:, :, :, swizzle_map] = mat_in - - # copy mat_scale_in to scale_part - # scale to fp32 - mat_scale_in_fp32 = mat_scale_in.to(torch.float32).reshape(9, 128, 16) # 7x2 + 2 zeros - scale_part.copy_(mat_scale_in_fp32.view(dtype=torch.float8_e4m3fn)) - return mat_and_scale_in - - -def gen_expert_down_allreduce_fp8_params( - mat_in: torch.Tensor, mat_scale_in: torch.Tensor -) -> torch.Tensor: - """Convert tilert mat and scale to tilert-fp8 input format.""" - mat_and_scale_in = torch.zeros( - (257, 128, (56 * 256 + 64)), dtype=torch.float8_e4m3fn, device=mat_in.device - ) - scale_part = mat_and_scale_in[..., 56 * 256 :] - mat_part = mat_and_scale_in[..., : 56 * 256].reshape(257, 128, 2, 56 * 8, 16) - mat_in = mat_in.reshape(257, 128, 56, 2, 8, 16) - mat_in = mat_in.transpose(2, 3).reshape(257, 128, 2, 56 * 8, 16) - - swizzle_map = gen_tensor_swizzle_map_1d(56, 8, SwizzleMode.SWIZZLE_128B) - mat_part[:, :, :, swizzle_map] = mat_in - - # copy mat_scale_in to scale_part - # scale to fp32 - mat_scale_in_fp32 = mat_scale_in.to(torch.float32).reshape(257, 128, 16) # 7x2 + 2 zeros - scale_part.copy_(mat_scale_in_fp32.view(dtype=torch.float8_e4m3fn)) - return mat_and_scale_in - - -def gen_unproj_o_allreduce_fp8_params( - mat_in: torch.Tensor, mat_scale_in: torch.Tensor -) -> torch.Tensor: - """Convert tilert mat and scale to tilert-fp8 input format.""" - mat_and_scale_in = torch.zeros( - (128, 4, (56 * 512 + 8 * 4 * 4)), dtype=torch.float8_e4m3fn, device=mat_in.device - ) - scale_part = mat_and_scale_in[..., 56 * 512 :].reshape(128, 4, 128) - mat_part = mat_and_scale_in[..., : 56 * 512].reshape(128, 4, 4, 56 * 8, 16) - - mat_in = mat_in.reshape(128, 56, 4, 512) - mat_in = ( - mat_in.transpose(1, 2) - .reshape(128, 4, 56, 4, 128) - .transpose(2, 3) - .reshape(128, 4, 4, 56 * 8, 16) - ) - - swizzle_map = gen_tensor_swizzle_map_1d(56, 8, SwizzleMode.SWIZZLE_128B) - mat_part[:, :, :, swizzle_map] = mat_in - # 896x16 - mat_scale_in_fp32 = mat_scale_in.to(torch.float32).reshape(128, 7, 16) - # padding to 1024x16 - zeros = torch.zeros((128, 1, 16), dtype=torch.float32, device=mat_scale_in.device) - mat_scale_in_fp32 = torch.cat([mat_scale_in_fp32, zeros], dim=1).contiguous() - # transpose - mat_scale_in_fp32 = mat_scale_in_fp32.reshape(128, 8, 4, 4).transpose(1, 2).contiguous() - scale_part.copy_(mat_scale_in_fp32.view(dtype=torch.float8_e4m3fn).reshape(128, 4, 128)) - - return mat_and_scale_in.contiguous() - - -class IntermediateMapper: - """Map the intermediate tensors to the corresponding variables.""" - - def __init__(self, intermediate_list: list[torch.Tensor]): - self.q = intermediate_list[0] - self.kv = intermediate_list[1] - self.ki = intermediate_list[2] - self.q_nope_down = intermediate_list[3] - self.q_pe = intermediate_list[4] - self.iq = intermediate_list[5] - self.iq_rt = intermediate_list[6] - self.idx_score = intermediate_list[7] - self.idx_logits = intermediate_list[8] - self.idx_sels = intermediate_list[9] - self.q_nope = intermediate_list[10] - self.o = intermediate_list[11] - self.o_acc = intermediate_list[12] - self.o_lse = intermediate_list[13] - self.o_lse_acc = intermediate_list[14] - self.proj_o = intermediate_list[15] - self.unproj_o = intermediate_list[16] - self.scores = intermediate_list[17] - self.x_mlp_in = intermediate_list[18] - self.exp_up_gate = intermediate_list[19] - self.sel_probs = intermediate_list[20] - self.sel_indices = intermediate_list[21] - self.exp_out = intermediate_list[22] - self.x_rmsnorm = intermediate_list[23] - self.logits_out = intermediate_list[24] - self.token_out = intermediate_list[25] - - -class BaseParams: - def __init__(self) -> None: - self._params: list[torch.Tensor] = [] - - def register_params(self, param: torch.Tensor) -> torch.Tensor: - self._params.append(param) - return param - - def get_params(self) -> list[torch.Tensor]: - return self._params - - @staticmethod - @abstractmethod - def num_params() -> int: - raise NotImplementedError("Subclasses must implement this method") - - -class PlaceHolderParams(BaseParams): - def __init__(self) -> None: - super().__init__() - - @staticmethod - def num_params() -> int: - return 0 - - -class EmbeddingParams(BaseParams): - def __init__( - self, - embedding: torch.Tensor, - freqs_cis: torch.Tensor, - ): - super().__init__() - self.embedding = self.register_params(embedding) - self.freqs_cis = self.register_params(freqs_cis) - - @staticmethod - def num_params() -> int: - return 1 - - def to_dict(self, device_id: int) -> dict[str, torch.Tensor]: - return {"model.embed_tokens.weight": self.embedding.to(device_id)} - - -class MlaParams(BaseParams): - def __init__( - self, - x_rmsnorm_gamma: torch.Tensor, - qkv_wa_weights: torch.Tensor, - qkv_wa_scales: torch.Tensor, - k_weights: torch.Tensor, - k_bias: torch.Tensor, - q_rmsnorm_gamma: torch.Tensor, - q_wb_weights: torch.Tensor, - q_wb_scales: torch.Tensor, - id_score_weights: torch.Tensor, - wkv_b1_weights: torch.Tensor, - wkv_b1_scales: torch.Tensor, - kv_rmsnorm_gamma: torch.Tensor, - wkv_b2_weights: torch.Tensor, - wkv_b2_scales: torch.Tensor, - unproj_weights: torch.Tensor, - unproj_scales: torch.Tensor, - ) -> None: - super().__init__() - self.x_rmsnorm_gamma = self.register_params(x_rmsnorm_gamma) - self.qkv_wa_weights = self.register_params(qkv_wa_weights) - self.qkv_wa_scales = self.register_params(qkv_wa_scales) - self.k_weights = self.register_params(k_weights) - self.k_bias = self.register_params(k_bias) - self.q_rmsnorm_gamma = self.register_params(q_rmsnorm_gamma) - self.q_wb_weights = self.register_params(q_wb_weights) - self.q_wb_scales = self.register_params(q_wb_scales) - self.id_score_weights = self.register_params(id_score_weights) - self.wkv_b1_weights = self.register_params(wkv_b1_weights) - self.wkv_b1_scales = self.register_params(wkv_b1_scales) - self.kv_rmsnorm_gamma = self.register_params(kv_rmsnorm_gamma) - self.wkv_b2_weights = self.register_params(wkv_b2_weights) - self.wkv_b2_scales = self.register_params(wkv_b2_scales) - self.unproj_weights = self.register_params(unproj_weights) - self.unproj_scales = self.register_params(unproj_scales) - - @staticmethod - def num_params() -> int: - return 16 - - def to_dict(self, layer_id: int, device_id: int) -> dict[str, torch.Tensor]: - return { - f"layer_{layer_id}_x_rmsnorm_gamma_dev_{device_id}": self.x_rmsnorm_gamma.to(device_id), - f"layer_{layer_id}_qkv_wa_weights_dev_{device_id}": self.qkv_wa_weights.to(device_id), - f"layer_{layer_id}_qkv_wa_scales_dev_{device_id}": self.qkv_wa_scales.to(device_id), - f"layer_{layer_id}_k_weights_dev_{device_id}": self.k_weights.to(device_id), - f"layer_{layer_id}_k_bias_dev_{device_id}": self.k_bias.to(device_id), - f"layer_{layer_id}_q_rmsnorm_gamma_dev_{device_id}": self.q_rmsnorm_gamma.to(device_id), - f"layer_{layer_id}_q_wb_weights_dev_{device_id}": self.q_wb_weights.to(device_id), - f"layer_{layer_id}_q_wb_scales_dev_{device_id}": self.q_wb_scales.to(device_id), - f"layer_{layer_id}_id_score_weights_dev_{device_id}": self.id_score_weights.to( - device_id - ), - f"layer_{layer_id}_wkv_b1_weights_dev_{device_id}": self.wkv_b1_weights.to(device_id), - f"layer_{layer_id}_wkv_b1_scales_dev_{device_id}": self.wkv_b1_scales.to(device_id), - f"layer_{layer_id}_kv_rmsnorm_gamma_dev_{device_id}": self.kv_rmsnorm_gamma.to( - device_id - ), - f"layer_{layer_id}_wkv_b2_weights_dev_{device_id}": self.wkv_b2_weights.to(device_id), - f"layer_{layer_id}_wkv_b2_scales_dev_{device_id}": self.wkv_b2_scales.to(device_id), - f"layer_{layer_id}_unproj_weights_dev_{device_id}": self.unproj_weights.to(device_id), - f"layer_{layer_id}_unproj_scales_dev_{device_id}": self.unproj_scales.to(device_id), - } - - -class MLPParams(BaseParams): - def __init__( - self, - unproj_o_gamma: torch.Tensor, - upgate_weights: torch.Tensor, - upgate_scales: torch.Tensor, - down_weights: torch.Tensor, - down_scales: torch.Tensor, - ) -> None: - super().__init__() - self.unproj_o_gamma = self.register_params(unproj_o_gamma) - self.upgate_weights = self.register_params(upgate_weights) - self.upgate_scales = self.register_params(upgate_scales) - self.down_weights = self.register_params(down_weights) - self.down_scales = self.register_params(down_scales) - - @staticmethod - def num_params() -> int: - return 5 - - def to_dict(self, layer_id: int, device_id: int) -> dict[str, torch.Tensor]: - return { - f"layer_{layer_id}_unproj_o_gamma_dev_{device_id}": self.unproj_o_gamma.to(device_id), - f"layer_{layer_id}_upgate_weights_dev_{device_id}": self.upgate_weights.to(device_id), - f"layer_{layer_id}_upgate_scales_dev_{device_id}": self.upgate_scales.to(device_id), - f"layer_{layer_id}_down_weights_dev_{device_id}": self.down_weights.to(device_id), - f"layer_{layer_id}_down_scales_dev_{device_id}": self.down_scales.to(device_id), - } - - -class MoEParams(BaseParams): - def __init__( - self, - unproj_o_gamma: torch.Tensor, - exp_proj_weights: torch.Tensor, - exp_bias: torch.Tensor, - exp_upgate_weights: torch.Tensor, - exp_upgate_scales: torch.Tensor, - exp_down_weights: torch.Tensor, - exp_down_scales: torch.Tensor, - ) -> None: - super().__init__() - self.unproj_o_gamma = self.register_params(unproj_o_gamma) - self.exp_proj_weights = self.register_params(exp_proj_weights) - self.exp_bias = self.register_params(exp_bias) - self.exp_upgate_weights = self.register_params(exp_upgate_weights) - self.exp_upgate_scales = self.register_params(exp_upgate_scales) - self.exp_down_weights = self.register_params(exp_down_weights) - self.exp_down_scales = self.register_params(exp_down_scales) - - @staticmethod - def num_params() -> int: - return 7 - - def to_dict(self, layer_id: int, device_id: int) -> dict[str, torch.Tensor]: - return { - f"layer_{layer_id}_unproj_o_gamma_dev_{device_id}": self.unproj_o_gamma.to(device_id), - f"layer_{layer_id}_exp_proj_weights_dev_{device_id}": self.exp_proj_weights.to( - device_id - ), - f"layer_{layer_id}_exp_bias_dev_{device_id}": self.exp_bias.to(device_id), - f"layer_{layer_id}_exp_upgate_weights_dev_{device_id}": self.exp_upgate_weights.to( - device_id - ), - f"layer_{layer_id}_exp_upgate_scales_dev_{device_id}": self.exp_upgate_scales.to( - device_id - ), - f"layer_{layer_id}_exp_down_weights_dev_{device_id}": self.exp_down_weights.to( - device_id - ), - f"layer_{layer_id}_exp_down_scales_dev_{device_id}": self.exp_down_scales.to(device_id), - } - - -class LLMHeadParams(BaseParams): - """LLM Head Parameters""" - - def __init__( - self, - hidden_rms_gamma: torch.Tensor, - head_proj_weights: torch.Tensor, - ) -> None: - super().__init__() - self.hidden_rms_gamma = self.register_params(hidden_rms_gamma) - self.head_proj_weights = self.register_params(head_proj_weights) - - @staticmethod - def num_params() -> int: - return 2 - - def to_dict(self, layer_id: int, device_id: int) -> dict[str, torch.Tensor]: - return { - f"layer_{layer_id}_model.norm.weight_dev_{device_id}": self.hidden_rms_gamma.to( - device_id - ), - f"layer_{layer_id}_lm_head.weight_dev_{device_id}": self.head_proj_weights.to( - device_id - ), - } - - -class MTPPreprocessParams(BaseParams): - def __init__( - self, - embedding_rmsnorm_gamma: torch.Tensor, - hidden_rmsnorm_gamma: torch.Tensor, - eh_proj_weights: torch.Tensor, - ) -> None: - super().__init__() - self.embedding_rmsnorm_gamma = self.register_params(embedding_rmsnorm_gamma) - self.hidden_rmsnorm_gamma = self.register_params(hidden_rmsnorm_gamma) - self.eh_proj_weights = self.register_params(eh_proj_weights) - - @staticmethod - def num_params() -> int: - return 3 - - def to_dict(self, layer_id: int, device_id: int) -> dict[str, torch.Tensor]: - return { - f"layer_{layer_id}_embedding_rmsnorm_gamma_dev_{device_id}": ( - self.embedding_rmsnorm_gamma.to(device_id) - ), - f"layer_{layer_id}_hidden_rmsnorm_gamma_dev_{device_id}": self.hidden_rmsnorm_gamma.to( - device_id - ), - f"layer_{layer_id}_eh_proj_weights_dev_{device_id}": self.eh_proj_weights.to(device_id), - } - - -class TempVars(BaseParams): - def __init__( - self, - q: torch.Tensor, - kv: torch.Tensor, - ki: torch.Tensor, - q_nope_down: torch.Tensor, - q_pe: torch.Tensor, - iq: torch.Tensor, - iq_rt: torch.Tensor, - idx_score: torch.Tensor, - idx_logits: torch.Tensor, - idx_sels: torch.Tensor, - q_nope: torch.Tensor, - o: torch.Tensor, - o_acc: torch.Tensor, - o_lse: torch.Tensor, - o_lse_acc: torch.Tensor, - proj_o: torch.Tensor, - unproj_o: torch.Tensor, - scores: torch.Tensor, - x_mlp_in: torch.Tensor, - exp_up_gate: torch.Tensor, - sel_probs: torch.Tensor, - sel_indices: torch.Tensor, - exp_out: torch.Tensor, - x_rmsnorm: torch.Tensor, - logits_out: torch.Tensor, - token_out: torch.Tensor, - embedding_rmsnorm: torch.Tensor, - hidden_rmsnorm: torch.Tensor, - eh_proj: torch.Tensor, - x_tensor: torch.Tensor, - rope_freqs: torch.Tensor, - cur_pos: torch.Tensor, - token_id: torch.Tensor, - last_hidden_states: torch.Tensor, - draft_tokens: torch.Tensor, - predicted_tokens: torch.Tensor, - predicted_hidden: torch.Tensor, - accepted_tokens: torch.Tensor, - next_draft_tokens: torch.Tensor, - ) -> None: - super().__init__() - self.q = self.register_params(q) - self.kv = self.register_params(kv) - self.ki = self.register_params(ki) - self.q_nope_down = self.register_params(q_nope_down) - self.q_pe = self.register_params(q_pe) - self.iq = self.register_params(iq) - self.iq_rt = self.register_params(iq_rt) - self.idx_score = self.register_params(idx_score) - self.idx_logits = self.register_params(idx_logits) - self.idx_sels = self.register_params(idx_sels) - self.q_nope = self.register_params(q_nope) - self.o = self.register_params(o) - self.o_acc = self.register_params(o_acc) - self.o_lse = self.register_params(o_lse) - self.o_lse_acc = self.register_params(o_lse_acc) - self.proj_o = self.register_params(proj_o) - self.unproj_o = self.register_params(unproj_o) - self.scores = self.register_params(scores) - self.x_mlp_in = self.register_params(x_mlp_in) - self.exp_up_gate = self.register_params(exp_up_gate) - self.sel_probs = self.register_params(sel_probs) - self.sel_indices = self.register_params(sel_indices) - self.exp_out = self.register_params(exp_out) - self.x_rmsnorm = self.register_params(x_rmsnorm) - self.logits_out = self.register_params(logits_out) - self.token_out = self.register_params(token_out) - self.embedding_rmsnorm = self.register_params(embedding_rmsnorm) - self.hidden_rmsnorm = self.register_params(hidden_rmsnorm) - self.eh_proj = self.register_params(eh_proj) - self.x_tensor = self.register_params(x_tensor) - self.rope_freqs = self.register_params(rope_freqs) - self.cur_pos = self.register_params(cur_pos) - self.token_id = self.register_params(token_id) - self.last_hidden_states = self.register_params(last_hidden_states) - self.draft_tokens = self.register_params(draft_tokens) - self.predicted_tokens = self.register_params(predicted_tokens) - self.predicted_hidden = self.register_params(predicted_hidden) - self.accepted_tokens = self.register_params(accepted_tokens) - self.next_draft_tokens = self.register_params(next_draft_tokens) - - @staticmethod - def num_params() -> int: - return 39 - - def tot_size_in_bytes_aligned(self, aligned_size: int) -> int: - tot_size: int = 0 - for param in self._params: - aligned_param_size = (param.nbytes + aligned_size - 1) // aligned_size * aligned_size - tot_size += aligned_param_size - return tot_size - - def generate_params_with_continuous_storage( - self, device: torch.device, aligned_size: int = 1024 - ) -> list[torch.Tensor]: - tot_size = self.tot_size_in_bytes_aligned(aligned_size) - cloned_params = [] - large_tensor = torch.zeros(tot_size, device=device, dtype=torch.uint8) - offset = 0 - for param in self._params: - aligned_param_size = (param.nbytes + aligned_size - 1) // aligned_size * aligned_size - cloned_params.append( - large_tensor[offset : offset + param.nbytes].view(param.dtype).view(param.shape) - ) - offset += aligned_param_size - return cloned_params - - -class CacheVars(BaseParams): - def __init__( - self, - k_cache: torch.Tensor, - kv_cache: torch.Tensor, - pe_cache: torch.Tensor, - ) -> None: - super().__init__() - self.k_cache = self.register_params(k_cache) - self.kv_cache = self.register_params(kv_cache) - self.pe_cache = self.register_params(pe_cache) - - @staticmethod - def num_params() -> int: - return 3 - - -class Dsa671BModelInitializer: - """DSA with MTP e2e model for DeepSeek v3.2""" - - # TODO: These parameters should be carefully checked - BATCH_SIZE = 1 - MAX_SEQ_LEN = 4 - NUM_HEADS = 16 - NUM_KI_HEADS = 64 - - MAX_OPS = 2048 - MAX_SEL_TOKENS = 2048 - MAX_CTX_LEN = 16384 - NUM_DENSE_LAYERS = 3 - NUM_MOE_LAYERS = 58 - NUM_LAYERS = NUM_DENSE_LAYERS + NUM_MOE_LAYERS - - HIDDEN_SIZE = 7168 - PE_LORA_DIM = 64 - Q_NOPE_DIM = 128 - - Q_DIM = 1536 - KV_CACHE_DIM = 512 - PE_CACHE_DIM = 64 - KI_CACHE_DIM = 128 - Q_PE_DIM = 512 - V_HEAD_DIM = 128 - - N_ROUTED_EXPERTS = 256 - N_ACTIVATE_EXPERTS = 8 - N_TOTAL_EXPERTS = N_ACTIVATE_EXPERTS + 1 - EXP_DIMS = 256 - - FULL_VOCAB_SIZE = 129280 - VOCAB_SIZE = FULL_VOCAB_SIZE // 8 # 16160 - - def __init__( - self, - device: torch.device, - max_seq_len: int | None = None, - max_ctx_len: int | None = None, - with_weight_conversion: bool = True, - with_mtp: bool = False, - ) -> None: - super().__init__() - - self.device = device - self.max_seq_len = max_seq_len if max_seq_len is not None else self.MAX_SEQ_LEN - self.max_ctx_len = max_ctx_len if max_ctx_len is not None else self.MAX_CTX_LEN - self.with_weight_conversion = with_weight_conversion - self.with_mtp = with_mtp - - self.bf16_desc = {"dtype": torch.bfloat16, "device": device} - self.fp16_desc = {"dtype": torch.float16, "device": device} - self.fp32_desc = {"dtype": torch.float32, "device": device} - self.uint64_desc = {"dtype": torch.uint64, "device": device} - self.int32_desc = {"dtype": torch.int32, "device": device} - self.uint8_desc = {"dtype": torch.uint8, "device": device} - - self.mtp_params_sidx = 0 - - def register_weights_and_scales( - self, dim1: int, dim2: int - ) -> tuple[torch.Tensor, torch.Tensor]: - block_size = 128 - weights_dims = (dim1, dim2) - weights = torch.randn(weights_dims, **self.bf16_desc).to(torch.float8_e4m3fn) - scales = torch.randn((dim1, dim2 // block_size), **self.bf16_desc) - return weights, scales - - def init_llm_head_params(self) -> LLMHeadParams: - from tilert.models.preprocess.weight_utils import RMSNormHeadProjWeightsConverter - - hidden_rms_gamma_shape = (self.HIDDEN_SIZE,) - head_proj_weights_shape = (self.VOCAB_SIZE, self.HIDDEN_SIZE) - - hidden_rms_gamma = torch.randn(hidden_rms_gamma_shape, **self.fp32_desc) - head_proj_weights = torch.randn(head_proj_weights_shape, **self.bf16_desc) - - if self.with_weight_conversion: - # Apply weight conversion for LLM head - head_proj_weights = ( - RMSNormHeadProjWeightsConverter.tilert_to_tilert_native_bf16_warp_gemv( - head_proj_weights - ) - ) - - return LLMHeadParams( - hidden_rms_gamma, - head_proj_weights, - ) - - def init_embedding_params(self) -> EmbeddingParams: - embedding = torch.randn(self.FULL_VOCAB_SIZE, self.HIDDEN_SIZE, **self.bf16_desc) - freqs_cis = precompute_freqs_cis(ModelArgsV3_2()) - freqs_cis = torch.view_as_real(freqs_cis).reshape(freqs_cis.shape[0], -1) - return EmbeddingParams(embedding, freqs_cis.to(self.device)) - - def init_mla_params(self) -> MlaParams: - from tilert.models.preprocess.weight_utils import ( - RMSNormProjQAKVAKIWeightsConverter, - ) - - qkv_dim = self.Q_DIM + self.KV_CACHE_DIM + self.PE_LORA_DIM + 128 - x_rmsnorm_gamma_shape = (self.HIDDEN_SIZE,) - q_wb_shape = ((self.PE_LORA_DIM + self.Q_NOPE_DIM + 512) * self.NUM_HEADS, self.Q_DIM) - wkv_b1_shape = (self.NUM_HEADS, self.KV_CACHE_DIM, self.V_HEAD_DIM) - wkv_b2_shape = (self.NUM_HEADS, self.V_HEAD_DIM, self.KV_CACHE_DIM) - wkv_b2_scales_shape = (self.NUM_HEADS, self.V_HEAD_DIM // 128, self.KV_CACHE_DIM // 128) - unproj_w_shape = (self.HIDDEN_SIZE, self.NUM_HEADS * self.V_HEAD_DIM) - unproj_scales_shape = (896, self.NUM_HEADS * self.V_HEAD_DIM // 128) - - x_rmsnorm_gamma = torch.randn(x_rmsnorm_gamma_shape, **self.fp32_desc) - qkv_wa_weights, _ = self.register_weights_and_scales(qkv_dim, self.HIDDEN_SIZE) - qkv_wa_scales = torch.randn((130, 64), **self.bf16_desc) - k_weights = torch.randn(128, **self.fp32_desc) - k_bias = torch.randn(128, **self.fp32_desc) - q_rmsnorm_gamma = torch.randn(self.Q_DIM, **self.fp32_desc) - q_wb_weights, _ = self.register_weights_and_scales(*q_wb_shape) - q_wb_scales = torch.randn((448, 12), **self.bf16_desc) - id_score_weights = torch.randn(64, self.HIDDEN_SIZE, **self.bf16_desc) - wkv_b1_weights = torch.randn(wkv_b1_shape, **self.fp16_desc).to(torch.float8_e4m3fn) - wkv_b1_scales = torch.randn((16, 8, 1), **self.bf16_desc) - kv_rmsnorm_gamma = torch.randn(self.KV_CACHE_DIM, **self.fp32_desc) - wkv_b2_weights = torch.randn(wkv_b2_shape, **self.fp16_desc).to(torch.float8_e4m3fn) - wkv_b2_scales = torch.randn(wkv_b2_scales_shape, **self.bf16_desc) - unproj_weights = torch.randn(unproj_w_shape, **self.fp16_desc).to(torch.float8_e4m3fn) - unproj_scales = torch.randn(unproj_scales_shape, **self.bf16_desc) - - if self.with_weight_conversion: - # Apply weight conversion for MLA qkv_wa weights - # Convert tilert format -> common format -> tilert native bf16 warp gemv format - common_weights = RMSNormProjQAKVAKIWeightsConverter.tilert_to_common( - qkv_wa_weights, qkv_wa_scales, x_rmsnorm_gamma - ) - qkv_wa_weights, x_rmsnorm_gamma = ( - RMSNormProjQAKVAKIWeightsConverter.common_to_tilert_native_bf16_warp_gemv( - *common_weights - ) - ) - - return MlaParams( - x_rmsnorm_gamma=x_rmsnorm_gamma, - qkv_wa_weights=qkv_wa_weights, - qkv_wa_scales=qkv_wa_scales, - k_weights=k_weights, - k_bias=k_bias, - q_rmsnorm_gamma=q_rmsnorm_gamma, - q_wb_weights=q_wb_weights, - q_wb_scales=q_wb_scales, - id_score_weights=id_score_weights, - wkv_b1_weights=wkv_b1_weights, - wkv_b1_scales=wkv_b1_scales, - kv_rmsnorm_gamma=kv_rmsnorm_gamma, - wkv_b2_weights=wkv_b2_weights, - wkv_b2_scales=wkv_b2_scales, - unproj_weights=unproj_weights, - unproj_scales=unproj_scales, - ) - - def init_mlp_params(self) -> MLPParams: - exp_upgate_w_shape = (9, self.EXP_DIMS * 2, self.HIDDEN_SIZE) - exp_upgate_s_shape = (9, self.EXP_DIMS * 2 // 128, 64) - exp_down_w_shape = (9, self.HIDDEN_SIZE, self.EXP_DIMS) - exp_down_s_shape = (9, 1024, self.EXP_DIMS // 128) - - unproj_o_gamma = torch.randn(self.HIDDEN_SIZE, **self.fp32_desc) - upgate_weights = torch.randn(exp_upgate_w_shape, **self.fp16_desc).to(torch.float8_e4m3fn) - upgate_scales = torch.randn(exp_upgate_s_shape, **self.bf16_desc) - down_weights = torch.randn(exp_down_w_shape, **self.fp16_desc).to(torch.float8_e4m3fn) - down_scales = torch.randn(exp_down_s_shape, **self.bf16_desc) - - return MLPParams( - unproj_o_gamma, - upgate_weights, - upgate_scales, - down_weights, - down_scales, - ) - - def init_moe_params(self) -> MoEParams: - from tilert.models.preprocess.weight_utils import ( - ExpertSelectUpGateSiLUWeightsConverter, - ) - - exp_ug_w_shape = (self.N_ROUTED_EXPERTS + 1, self.EXP_DIMS * 2, self.HIDDEN_SIZE) - exp_upgate_s_shape = (self.N_ROUTED_EXPERTS + 1, self.EXP_DIMS * 2 // 128, 64) - exp_down_w_shape = (self.N_ROUTED_EXPERTS + 1, self.HIDDEN_SIZE, self.EXP_DIMS) - exp_down_s_shape = (self.N_ROUTED_EXPERTS + 1, 1024, self.EXP_DIMS // 128) - - unproj_o_gamma = torch.randn(self.HIDDEN_SIZE, **self.fp32_desc) - exp_proj_weights = torch.randn((self.N_ROUTED_EXPERTS, self.HIDDEN_SIZE), **self.bf16_desc) - exp_bias = torch.randn(self.N_ROUTED_EXPERTS, **self.fp32_desc) - exp_upgate_weights = torch.randn(exp_ug_w_shape, **self.fp16_desc).to(torch.float8_e4m3fn) - exp_upgate_scales = torch.randn(exp_upgate_s_shape, **self.bf16_desc) - exp_down_weights = torch.randn(exp_down_w_shape, **self.fp16_desc).to(torch.float8_e4m3fn) - exp_down_scales = torch.randn(exp_down_s_shape, **self.bf16_desc) - - if self.with_weight_conversion: - # Apply weight conversion for MOE exp_upgate weights - exp_upgate_weights = ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm_mma( - exp_upgate_weights, exp_upgate_scales - ) - - return MoEParams( - unproj_o_gamma, - exp_proj_weights, - exp_bias, - exp_upgate_weights, - exp_upgate_scales, - exp_down_weights, - exp_down_scales, - ) - - def init_mtp_preprocess_params(self) -> MTPPreprocessParams: - """Initialize MTP preprocess parameters with random values.""" - embedding_rmsnorm_gamma = torch.randn(self.HIDDEN_SIZE, **self.fp32_desc) - hidden_rmsnorm_gamma = torch.randn(self.HIDDEN_SIZE, **self.fp32_desc) - eh_proj_weights = torch.randn((128, 7, 56, 256), **self.bf16_desc) - return MTPPreprocessParams( - embedding_rmsnorm_gamma, - hidden_rmsnorm_gamma, - eh_proj_weights, - ) - - def acquire_params(self) -> list[torch.Tensor]: - params = [] - - for _ in range(self.NUM_DENSE_LAYERS): - params.extend(self.init_mla_params().get_params()) - params.extend(self.init_mlp_params().get_params()) - - for _ in range(self.NUM_MOE_LAYERS): - params.extend(self.init_mla_params().get_params()) - params.extend(self.init_moe_params().get_params()) - - params.extend(self.init_llm_head_params().get_params()) - params.extend(self.init_embedding_params().get_params()) - - if self.with_mtp: - self.mtp_params_sidx = len(params) - params.extend(self.init_embedding_params().get_params()) - params.extend(self.init_mtp_preprocess_params().get_params()) - params.extend(self.init_mla_params().get_params()) - params.extend(self.init_moe_params().get_params()) - params.extend(self.init_llm_head_params().get_params()) - - return params - - def acquire_temp_vars(self, seq_len: int | None = None) -> TempVars: - """Acquire temporary variables for the model. - - Args: - seq_len: Sequence length for temp vars. If None, uses self.max_seq_len. - - Returns: - TempVars object containing all temporary tensors. - """ - seq_len = seq_len if seq_len is not None else self.max_seq_len - BATCH_SEQ = (self.BATCH_SIZE, seq_len) - - q = torch.zeros(*BATCH_SEQ, self.Q_DIM, **self.bf16_desc) - kv = torch.zeros(*BATCH_SEQ, self.KV_CACHE_DIM, **self.bf16_desc) - q_pe = torch.zeros(*BATCH_SEQ, self.NUM_HEADS, self.PE_LORA_DIM, **self.bf16_desc) - ki = torch.zeros(*BATCH_SEQ, self.KI_CACHE_DIM, **self.bf16_desc) - q_nope_down = torch.zeros(*BATCH_SEQ, self.NUM_HEADS, self.V_HEAD_DIM, **self.bf16_desc) - q_nope = torch.zeros(*BATCH_SEQ, self.NUM_HEADS, self.Q_PE_DIM, **self.bf16_desc) - iq = torch.zeros(*BATCH_SEQ, self.NUM_KI_HEADS, self.KI_CACHE_DIM, **self.bf16_desc) - iq_rt = torch.zeros(*BATCH_SEQ, self.NUM_KI_HEADS, self.KI_CACHE_DIM, **self.bf16_desc) - idx_score = torch.zeros(*BATCH_SEQ, self.NUM_KI_HEADS, **self.bf16_desc) - idx_logits = torch.zeros(*BATCH_SEQ, self.max_ctx_len, **self.fp32_desc) - idx_sels = torch.zeros(*BATCH_SEQ, self.MAX_SEL_TOKENS, **self.int32_desc) - o = torch.zeros(*BATCH_SEQ, self.NUM_HEADS, self.KV_CACHE_DIM, **self.bf16_desc) - o_acc = torch.zeros(*BATCH_SEQ, self.NUM_HEADS, 32, self.KV_CACHE_DIM, **self.fp32_desc) - o_lse = torch.empty(*BATCH_SEQ, self.NUM_HEADS, **self.fp32_desc) - o_lse_acc = torch.empty(*BATCH_SEQ, self.NUM_HEADS, 32, **self.fp32_desc) - proj_o = torch.zeros(*BATCH_SEQ, self.NUM_HEADS, self.V_HEAD_DIM, **self.bf16_desc) - unproj_o = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc) - scores = torch.zeros(*BATCH_SEQ, self.N_ROUTED_EXPERTS, **self.fp32_desc) - x_mlp_in = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc) - exp_up_gate = torch.zeros(*BATCH_SEQ, self.N_TOTAL_EXPERTS, self.EXP_DIMS, **self.bf16_desc) - sel_probs = torch.zeros(*BATCH_SEQ, self.N_ACTIVATE_EXPERTS, **self.fp32_desc) - sel_indices = torch.zeros(*BATCH_SEQ, self.N_ACTIVATE_EXPERTS, **self.int32_desc) - exp_out = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc) - x_rmsnorm = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc) - logits_out = torch.zeros(*BATCH_SEQ, self.VOCAB_SIZE, **self.fp32_desc) - token_out = torch.zeros(*BATCH_SEQ, 1, **self.int32_desc) - - embedding_rmsnorm = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc) - hidden_rmsnorm = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc) - eh_proj = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc) - x_tensor = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc) - rope_freqs = torch.zeros(*BATCH_SEQ, self.PE_CACHE_DIM, **self.fp32_desc) - cur_pos = torch.zeros(self.BATCH_SIZE, **self.int32_desc) - token_id = torch.zeros(*BATCH_SEQ, 1, **self.int32_desc) - last_hidden_states = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc) - - draft_tokens = torch.zeros(*BATCH_SEQ, **self.int32_desc) - predicted_tokens = torch.zeros(*BATCH_SEQ, 1, **self.int32_desc) - predicted_hidden = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc) - accepted_tokens = torch.zeros(self.BATCH_SIZE, **self.int32_desc) - next_draft_tokens = torch.zeros(*BATCH_SEQ, **self.int32_desc) - - return TempVars( - q, - kv, - ki, - q_nope_down, - q_pe, - iq, - iq_rt, - idx_score, - idx_logits, - idx_sels, - q_nope, - o, - o_acc, - o_lse, - o_lse_acc, - proj_o, - unproj_o, - scores, - x_mlp_in, - exp_up_gate, - sel_probs, - sel_indices, - exp_out, - x_rmsnorm, - logits_out, - token_out, - embedding_rmsnorm, - hidden_rmsnorm, - eh_proj, - x_tensor, - rope_freqs, - cur_pos, - token_id, - last_hidden_states, - draft_tokens, - predicted_tokens, - predicted_hidden, - accepted_tokens, - next_draft_tokens, - ) - - def acquire_cache_vars(self, num_layers: int | None = None) -> list[torch.Tensor]: - """Acquire cache variables for the model. - - Args: - num_layers: Number of layers to create cache for. If None, uses NUM_LAYERS. - - Returns: - List of cache tensors (3 tensors per layer: k_cache, kv_cache, pe_cache). - """ - num_layers = num_layers if num_layers is not None else self.NUM_LAYERS - if self.with_mtp: - num_layers += 1 - - BATCH_CTX = (self.BATCH_SIZE, self.max_ctx_len) - cache_vars = [] - for _ in range(num_layers): - cache_vars.extend( - [ - torch.zeros(*BATCH_CTX, self.KI_CACHE_DIM, **self.bf16_desc), - torch.zeros(*BATCH_CTX, self.KV_CACHE_DIM, **self.bf16_desc), - torch.zeros(*BATCH_CTX, self.PE_CACHE_DIM, **self.bf16_desc), - ] - ) - return cache_vars - - def acquire_single_layer_cache_vars(self) -> list[torch.Tensor]: - """Acquire cache variables for a single layer. - - Returns: - List of 3 cache tensors: k_cache, kv_cache, pe_cache. - """ - return self.acquire_cache_vars(num_layers=1) - - def acquire_misc_vars(self) -> list[torch.Tensor]: - return [ - torch.zeros(self.MAX_OPS, 148, 16, **self.uint64_desc), - torch.zeros(self.MAX_OPS, 128, **self.uint8_desc), - torch.zeros(self.MAX_OPS, 8, **self.uint8_desc), - ] - - def get_mtp_all_vars(self) -> list[torch.Tensor]: - return [ - self.acquire_params()[self.mtp_params_sidx :], - self.acquire_temp_vars().get_params(), - self.acquire_cache_vars()[-3:], - # Potential issue: Reallocate misc vars for MTP - self.acquire_misc_vars(), - ] diff --git a/python/models/deepseek_v3_2/refs/__init__.py b/python/models/deepseek_v3_2/refs/__init__.py new file mode 100644 index 0000000..25e6872 --- /dev/null +++ b/python/models/deepseek_v3_2/refs/__init__.py @@ -0,0 +1,13 @@ +"""DeepSeek v3.2 reference kernels (tilelang/triton implementations). + +This package exposes helpers like `act_quant`, `fp8_gemm`, and `weight_dequant` +for tests and higher-level Python ops. +""" + +from .kernel import act_quant, fp8_gemm, weight_dequant + +__all__ = [ + "act_quant", + "fp8_gemm", + "weight_dequant", +] diff --git a/python/models/deepseek_v3_2/refs/kernel.py b/python/models/deepseek_v3_2/refs/kernel.py new file mode 100644 index 0000000..eb5e274 --- /dev/null +++ b/python/models/deepseek_v3_2/refs/kernel.py @@ -0,0 +1,354 @@ +try: + import tilelang + import tilelang.language as T +except ImportError: + raise ImportError("Cannot import tilelang, please install tilelang.") from None + + +import torch +import triton +import triton.language as tl + +__all__ = [ + "weight_dequant", + "act_quant", + "fp8_gemm", + "fp8_index", +] + +tilelang.set_log_level("WARNING") + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + # tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, +} + +FP8 = "float8_e4m3" +BF16 = "bfloat16" +FP32 = "float32" + + +def fast_log2_ceil(x): # type: ignore + bits_x = T.reinterpret("uint32", x) + exp_x = (bits_x >> 23) & 0xFF + man_bits = bits_x & ((1 << 23) - 1) + return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) + + +def fast_pow2(x): # type: ignore + bits_x = (x + 127) << 23 + return T.reinterpret("float32", bits_x) + + +def fast_round_scale(amax, fp8_max_inv): # type: ignore + return fast_pow2(fast_log2_ceil(amax * fp8_max_inv)) + + +@triton.jit +def weight_dequant_kernel( # type: ignore + x_ptr, + s_ptr, + y_ptr, + M_Size: tl.constexpr, + N_Size: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +) -> None: + """ + Weight dequantization kernel. + + Dequantizes weights using the provided scaling factors and stores the + result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized + weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n_size = tl.cdiv(N_Size, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N_Size + offs_n[None, :] + mask = (offs_m[:, None] < M_Size) & (offs_n[None, :] < N_Size) + x_in = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s_in = tl.load(s_ptr + pid_m * n_size + pid_n) + y_out = x_in * s_in + tl.store(y_ptr + offs, y_out, mask=mask) + + +def weight_dequant(x_in: torch.Tensor, s_in: torch.Tensor, block_size: int = 128) -> torch.Tensor: + """ + Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x_in (torch.Tensor): The quantized weight tensor of shape (M, N). + s_in (torch.Tensor): The scale tensor of shape (M//block_size, + N//block_size). + block_size (int, optional): The block size to use for dequantization. + Defaults to 128. + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. + + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions + are not 2. + """ + assert x_in.is_contiguous() and s_in.is_contiguous(), "Input tensors must be contiguous" + assert x_in.dim() == 2 and s_in.dim() == 2, "Input tensors must have 2 dimensions" + M_Size, N_Size = x_in.size() + y_out = torch.empty_like(x_in, dtype=torch.get_default_dtype()) + grid = lambda meta: ( # noqa: E731 + triton.cdiv(M_Size, meta["BLOCK_SIZE"]), + triton.cdiv(N_Size, meta["BLOCK_SIZE"]), + ) + weight_dequant_kernel[grid](x_in, s_in, y_out, M_Size, N_Size, BLOCK_SIZE=block_size) + return y_out + + +@tilelang.jit(pass_configs=pass_configs) +def act_quant_kernel( # type: ignore + N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False # type: ignore +): # type: ignore + M = T.symbolic("M") + fp8_min = -448.0 + fp8_max = 448.0 + fp8_max_inv = 1 / fp8_max + num_stages = 0 if round_scale else 2 + blk_m = 32 + group_size = 128 + + @T.prim_func + def act_quant_kernel_( # type: ignore + X: T.Tensor[(M, N), in_dtype], + Y: T.Tensor[(M, N), out_dtype], + S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype], + ): # type: ignore + with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as ( + pid_m, + pid_n, + ): + x_shared = T.alloc_shared((blk_m, group_size), in_dtype) + x_local = T.alloc_fragment((blk_m, group_size), in_dtype) + amax_local = T.alloc_fragment((blk_m,), scale_dtype) + s_local = T.alloc_fragment((blk_m,), scale_dtype) + y_local = T.alloc_fragment((blk_m, group_size), out_dtype) + y_shared = T.alloc_shared((blk_m, group_size), out_dtype) + + for _ in T.Pipelined(1, num_stages=num_stages): + T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared) + T.copy(x_shared, x_local) + T.reduce_absmax(x_local, amax_local, dim=1) + for i in T.Parallel(blk_m): + amax_local[i] = T.max(amax_local[i], 1e-4) + if round_scale: + s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv) + else: + s_local[i] = amax_local[i] * fp8_max_inv + for i, j in T.Parallel(blk_m, group_size): + y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], fp8_min, fp8_max) + for i in T.Parallel(blk_m): + S[pid_m * blk_m + i, pid_n] = s_local[i] + T.copy(y_local, y_shared) + T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size]) + + return act_quant_kernel_ + + +def act_quant( + x: torch.Tensor, block_size: int = 128, scale_fmt: str | None = None +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. + Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. + Default is 128. + scale_fmt (Optional[str], optional): The format of the scale. Default is None. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ + assert x.is_contiguous(), "Input tensor must be contiguous" + assert ( + x.size(-1) % block_size == 0 + ), f"Last dimension size must be divisible by block_size (block_size={block_size})" + N = x.size(-1) + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) + kernel = act_quant_kernel(N, round_scale=scale_fmt is not None) + kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size)) + return y, s + + +@tilelang.jit(pass_configs=pass_configs) +def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"): # type: ignore + assert out_dtype in [BF16, "float32"] + + M = T.symbolic("M") + group_size = 128 + block_M = 32 + block_N = 128 + block_K = 128 + + @T.prim_func + def fp8_gemm_kernel_( # type: ignore + A: T.Tensor[(M, K), FP8], + B: T.Tensor[(N, K), FP8], + C: T.Tensor[(M, N), out_dtype], + scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32], + scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32], + ): # type: ignore + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as ( + bx, + by, + ): + A_shared = T.alloc_shared((block_M, block_K), FP8) + B_shared = T.alloc_shared((block_N, block_K), FP8) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + Scale_C_shared = T.alloc_shared((block_M), FP32) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + T.clear(C_local_accum) + K_iters = T.ceildiv(K, block_K) + for k in T.Pipelined(K_iters, num_stages=4): + # Load A into shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + # Load B into shared memory + T.copy(B[bx * block_N, k * block_K], B_shared) + # Load scale into shared memory + Scale_B = scales_b[bx * block_N // group_size, k] + for i in T.Parallel(block_M): + Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B + + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + # Promote to enable 2xAcc + for i, j in T.Parallel(block_M, block_N): + C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i] + T.clear(C_local) + # TMA store + T.copy(C_local_accum, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return fp8_gemm_kernel_ + + +def fp8_gemm( + a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor +) -> torch.Tensor: + """ + Perform a matrix multiplication using FP8 precision. + + Args: + a (torch.Tensor): The first input matrix, must be contiguous. + a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. + b (torch.Tensor): The second input matrix, must be contiguous. + b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. + """ + assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous" + assert a_s.is_contiguous() and b_s.is_contiguous(), "Scaling factor tensors must be contiguous" + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + kernel = fp8_gemm_kernel(N, K) + kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s) + return c + + +@tilelang.jit(out_idx=[4], pass_configs=pass_configs) +def fp8_index_kernel(h: int, d: int): # type: ignore + b = T.symbolic("b") + m = T.symbolic("m") + n = T.symbolic("n") + + blk_n1 = 512 + blk_n2 = 128 + + @T.prim_func + def fp8_index_kernel_( + q: T.Tensor[(b, m, h, d), FP8], + q_s: T.Tensor[(b, m, h), FP32], + k: T.Tensor[(b, n, d), FP8], + k_s: T.Tensor[(b, n), FP32], + o: T.Tensor[(b, m, n), FP32], + ) -> None: + with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n): + q_smem = T.alloc_shared((h, d), FP8) + T.copy(q[i_b, i_m, 0, 0], q_smem) + + q_s_frag = T.alloc_fragment(h, FP32) + T.copy(q_s[i_b, i_m, 0], q_s_frag) + + for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2): + k_smem = T.alloc_shared((blk_n2, d), FP8) + T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem) + + k_s_frag = T.alloc_fragment(blk_n2, FP32) + T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag) + + logits = T.alloc_fragment((blk_n2, h), FP32) + T.gemm( + k_smem, + q_smem, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + + for i_h, i3_n in T.Parallel(h, blk_n2): + logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h] + + logits_sum = T.alloc_fragment(blk_n2, FP32) + T.reduce_sum(logits, logits_sum, dim=1) + + for i3_n in T.Parallel(blk_n2): + logits_sum[i3_n] *= k_s_frag[i3_n] + + T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2]) + + return fp8_index_kernel_ + + +def fp8_index( + q: torch.Tensor, + q_s: torch.Tensor, + k: torch.Tensor, + k_s: torch.Tensor, +) -> torch.Tensor: + """ + Perform index score using FP8 precision. + + Args: + q (torch.Tensor): The Q tensor, must be contiguous. + q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous. + k (torch.Tensor): The K tensor, must be contiguous. + k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous. + + fp8 q @ fp8 k -> fp32 logits + relu(fp32 logits) * q_s (weights) -> fp32 logits + fp32 logits -> fp32 logits_sum + fp32 logits_sum * k_s (e8m0) -> fp32 index_score + """ + return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s) diff --git a/python/models/deepseek_v3_2/temp_var_indices.py b/python/models/deepseek_v3_2/temp_var_indices.py new file mode 100644 index 0000000..552fa3f --- /dev/null +++ b/python/models/deepseek_v3_2/temp_var_indices.py @@ -0,0 +1,122 @@ +"""Named indices for DSA temporary variables. + +Mirrors the C++ ``DsaTempVars`` constants defined in +``include/lib/models/deepseek_v3_2/helper.hpp`` so that Python code can +reference temp_vars by name instead of magic numbers. + +Usage:: + + from tilert.models.deepseek_v3_2.temp_var_indices import Idx + + token_out = intermediates[Idx.TOKEN_OUT] # equivalent to intermediates[25] +""" + +from enum import IntEnum + + +class DsaTempVarIdx(IntEnum): + """Index constants for DSA temp_vars, mirroring C++ DsaTempVars.""" + + Q = 0 + KV = 1 + KI = 2 + Q_NOPE_DOWN = 3 + Q_PE = 4 + IQ = 5 + IQ_RT = 6 + IDX_SCORES = 7 + IDX_LOGITS = 8 + IDX_SELECTS = 9 + Q_NOPE = 10 + O = 11 # noqa: E741 — mirrors C++ DsaTempVars::O + O_ACC = 12 + O_LSE = 13 + O_LSE_ACC = 14 + PROJ_O = 15 + UNPROJ_O = 16 + SCORES = 17 + X_MLP_IN = 18 + UP_GATE = 19 + SEL_PROBS = 20 + SEL_INDICES = 21 + EXP_OUT = 22 + X_RMSNORM = 23 + LOGITS_OUT = 24 + TOKEN_OUT = 25 + EMBEDDING_RMSNORM = 26 + HIDDEN_RMSNORM = 27 + EH_PROJ = 28 + X_TENSOR = 29 + ROPE_FREQS = 30 + CUR_POS = 31 + TOKEN_ID = 32 + LAST_HIDDEN_STATES = 33 + DRAFT_TOKENS = 34 + PREDICTED_TOKENS = 35 + PREDICTED_HIDDEN = 36 + ACCEPTED_TOKENS = 37 + NEXT_DRAFT_TOKENS = 38 + X_QUANT = 39 + X_SCALE = 40 + MOE_UP_GATE = 41 + IDX_SEL_WS = 42 + MTP0_TOKEN_OUT = 43 + MTP1_TOKEN_OUT = 44 + MTP0_EXP_OUT = 45 + SAMPLING_SEED = 46 + SAMPLING_POSITIONS = 47 + SAMPLING_CONFIG = 48 + TOP_P_SCORES = 49 + TOP_P_DEBUG = 50 + + +# Sentinel: total number of temp vars. Must equal C++ DsaTempVars::temp_vars_size. +TEMP_VARS_SIZE = 51 + +# Short alias for convenient access +Idx = DsaTempVarIdx + + +def validate_temp_vars_layout() -> None: + """Validate that the Python enum matches the C++ DsaTempVars layout. + + Checks: + 1. Enum member count equals TEMP_VARS_SIZE. + 2. Indices are contiguous 0..TEMP_VARS_SIZE-1 with no gaps or duplicates. + 3. (If libtilert.so is loaded) C++ temp_vars_size matches Python TEMP_VARS_SIZE. + + Raises: + RuntimeError: If any validation check fails. + """ + members = list(DsaTempVarIdx) + + # Check member count + if len(members) != TEMP_VARS_SIZE: + raise RuntimeError( + f"DsaTempVarIdx has {len(members)} members but TEMP_VARS_SIZE={TEMP_VARS_SIZE}" + ) + + # Check contiguous indices + indices = sorted(m.value for m in members) + expected = list(range(TEMP_VARS_SIZE)) + if indices != expected: + missing = set(expected) - set(indices) + dupes = [i for i in indices if indices.count(i) > 1] + raise RuntimeError( + f"DsaTempVarIdx indices are not contiguous 0..{TEMP_VARS_SIZE - 1}. " + f"Missing: {missing}, Duplicates: {set(dupes)}" + ) + + # Check against C++ if the library is loaded + try: + import torch + + cpp_size = torch.ops.tilert.dsa_temp_vars_size() + if cpp_size != TEMP_VARS_SIZE: + raise RuntimeError( + f"Python TEMP_VARS_SIZE={TEMP_VARS_SIZE} != " + f"C++ DsaTempVars::temp_vars_size={cpp_size}" + ) + except (AttributeError, RuntimeError): + # Library not loaded or op not available — skip C++ check + pass diff --git a/python/models/glm_5/__init__.py b/python/models/glm_5/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/models/glm_5/generator.py b/python/models/glm_5/generator.py new file mode 100644 index 0000000..ed7fac0 --- /dev/null +++ b/python/models/glm_5/generator.py @@ -0,0 +1,584 @@ +"""DSA show hands for GLM5.""" + +import os +import time + +import torch +from transformers import AutoTokenizer + +from tilert import logger +from tilert.models.deepseek_v3_2.generator import stats_time +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.deepseek_v3_2.modules.end2end import ShowHandsDSALayer +from tilert.models.deepseek_v3_2.temp_var_indices import Idx +from tilert.tilert_init import tilert_init + +__all__ = [ + "GLM5Generator", +] + + +class GLM5Generator: + """Show hands generator for GLM5.""" + + def __init__( + self, + model_args: ModelArgs, + max_new_tokens: int = 100, + temperature: float = 1.0, + model_weights_dir: str = "", + with_mtp: bool = False, + top_p: float = 0.9, + top_k: int = 256, + use_topp: bool = False, + enable_thinking: bool = False, + sampling_seed: int = 42, + ): + """Initialize the ShowHandsGeneratorGlm5. + + Args: + max_new_tokens: Maximum number of new tokens to generate. Defaults to 100. + temperature: Temperature for sampling. Defaults to 1.0. + model_weights_dir: Path of the model weights directory. + with_mtp: Whether to use MTP (Multi-Token Prediction) for speculative decoding. + top_p: Top-p (nucleus) sampling threshold. Defaults to 0.9. + top_k: Top-k sampling threshold. Defaults to 256. + use_topp: Whether to use top-p sampling. Defaults to False (top-1 argmax). + enable_thinking: Whether to enable thinking mode in chat template. + """ + torch.set_num_threads(64) + self.model_weights_dir = model_weights_dir + + self.max_new_tokens = max_new_tokens + self.temperature = temperature + self.with_mtp = with_mtp + self.enable_thinking = enable_thinking + self.sampling_seed = sampling_seed + + self.config = model_args + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_weights_dir, trust_remote_code=True + ) # nosec B615 + jinja_file_path = os.path.join(self.model_weights_dir, "chat_template.jinja") + with open(jinja_file_path, encoding="utf-8") as f: + chat_template = f.read() + self.tokenizer.chat_template = chat_template + self.eos_id = self.tokenizer.eos_token_id + self.batch_size = 1 # fixed batch size to 1 for now + self.mtp_seq_len = 4 + + # GLM5 uses multiple stop tokens + self.stop_tokens = ["<|user|>", "<|endoftext|>", "<|observation|>", "<|assistant|>"] + self.stop_token_ids: set[int] = set() + for token in self.stop_tokens: + token_ids = self.tokenizer.encode(token, add_special_tokens=False) + if len(token_ids) == 1: + self.stop_token_ids.add(token_ids[0]) + else: + # Try to get from added_tokens_encoder + if ( + hasattr(self.tokenizer, "added_tokens_encoder") + and token in self.tokenizer.added_tokens_encoder + ): + self.stop_token_ids.add(self.tokenizer.added_tokens_encoder[token]) + # Always include eos_id + if self.eos_id is not None: + self.stop_token_ids.add(self.eos_id) + logger.info(f"Stop token IDs: {self.stop_token_ids}") + + self.default_device = torch.device("cuda:0") + + self.decode_layer = ShowHandsDSALayer( + model_args=self.config, + model_path=self.model_weights_dir, + with_mtp=with_mtp, + top_p=top_p, + top_k=top_k, + use_topp=use_topp, + ) + + def init(self) -> None: + """Initialize the ShowHandsGeneratorGlm5.""" + tilert_init() + + def cleanup(self) -> None: + """Cleanup the ShowHandsGeneratorGlm5.""" + self.decode_layer.cleanup() + + def init_random_weights(self) -> None: + """Random initialize the weights.""" + self.decode_layer.init_random_weights() + + def from_pretrained(self) -> None: + """Load the model weights from the given path.""" + self.decode_layer.from_pretrained(self.model_weights_dir) + + def update_sampling_params( + self, + temperature: float = 1.0, + top_p: float = 0.95, + top_k: int = 256, + use_topp: bool = True, + ) -> None: + """Update sampling parameters for the next generation. + + Updates both the Python attributes and the CUDA sampling_config tensor + that the TileRT kernel reads during forward pass. + """ + self.temperature = temperature + self.decode_layer.update_sampling_config( + temperature=temperature, top_p=top_p, top_k=top_k, use_topp=use_topp + ) + + @torch.inference_mode() + def generate( + self, + prompt: str, + print_log: bool = True, + with_mtp: bool | None = None, + prompt_tokens: list[int] | None = None, + ) -> tuple[str, list[float], list[int]]: + """Main function to load the model and perform single sequence generation. + + Args: + prompt: The input prompt string. + print_log: Whether to print generation logs. + with_mtp: Override MTP mode for this call. None uses self.with_mtp. + Requires MTP weights to have been loaded (self.with_mtp=True). + prompt_tokens: Pre-tokenized prompt tokens. If provided, skip tokenization + and use these tokens directly (useful for exact-length benchmarking). + + Returns: + Tuple of (result_text, time_list, accepted_counts). + accepted_counts is empty for non-MTP mode. + """ + active_mtp = with_mtp if with_mtp is not None else self.with_mtp + if active_mtp and not self.with_mtp: + raise ValueError("Cannot use MTP mode: MTP weights were not loaded") + self.decode_layer.set_sampling_seed(self.sampling_seed, with_mtp=active_mtp) + if active_mtp: + return self._generate_with_mtp(prompt, print_log, prompt_tokens=prompt_tokens) + result, time_list = self._generate_without_mtp( + prompt, print_log, with_mtp=active_mtp, prompt_tokens=prompt_tokens + ) + return result, time_list, [] # Empty accepted_counts for non-MTP + + def _generate_without_mtp( + self, + prompt: str, + print_log: bool = True, + with_mtp: bool = False, + prompt_tokens: list[int] | None = None, + ) -> tuple[str, list[float]]: + """Standard generation without MTP.""" + if prompt_tokens is None: + messages = [{"role": "user", "content": prompt}] + prompt_tokens = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + enable_thinking=self.enable_thinking, + ) + + max_seq_len = self.config.max_seq_len + prompt_len = len(prompt_tokens) + total_len = min(max_seq_len, self.max_new_tokens + prompt_len) + + tokens = torch.full( + (self.batch_size, total_len), -1, dtype=torch.long, device=self.default_device + ) + tokens[0, :prompt_len] = torch.tensor( + prompt_tokens, dtype=torch.long, device=self.default_device + ) + prompt_mask = tokens != -1 + + prev_pos = 0 + finished = torch.tensor( + [False] * self.batch_size, dtype=torch.bool, device=self.default_device + ) + + time_list = [] + for cur_pos_val in range(1, total_len): + start_time = time.time() + multi_devices_results = self.decode_layer.forward( + tokens[0, prev_pos], with_mtp=with_mtp + ) + end_time = time.time() + time_list.append(end_time - start_time) + + intermediates, *_ = multi_devices_results[0] + next_token = intermediates[Idx.TOKEN_OUT][0][0] # only the first token + + # replace the next token with the prompt token if the prompt mask is True + next_token = torch.where( + prompt_mask[0, cur_pos_val], tokens[0, cur_pos_val], next_token + ) + tokens[0, cur_pos_val] = next_token + # Check if next_token is any of the stop tokens + is_stop_token = next_token.item() in self.stop_token_ids + finished |= torch.logical_and( + ~prompt_mask[0, cur_pos_val], + torch.tensor(is_stop_token, dtype=torch.bool, device=self.default_device), + ) + prev_pos = cur_pos_val + if cur_pos_val >= prompt_len: + decoded_tokens = self.tokenizer.decode( + [next_token.item()], skip_special_tokens=True + ) + if print_log: + print(decoded_tokens, end="", flush=True) + + if finished.all(): + break + + if print_log: + print("\n") + logger.info(f"--Number of tokens generated: {len(time_list)}") + + stats_time(time_list, "==== Performance ====") + print("\n") + + # Reset sequence after generation, i.e. reset the cur_pos to 0 internally + self.decode_layer.reset_sequence() + + completion_tokens = [] + for _, toks in enumerate(tokens.tolist()): + toks = toks[prompt_len : prompt_len + self.max_new_tokens] + # Find first stop token and truncate + stop_idx = len(toks) + for i, tok in enumerate(toks): + if tok in self.stop_token_ids: + stop_idx = i + break + toks = toks[:stop_idx] + completion_tokens.append(toks) + + decoded_tokens = self.tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) + + return f"{decoded_tokens[0]}\n" if decoded_tokens else "", time_list + + def _generate_with_mtp( + self, + prompt: str, + print_log: bool = True, + prompt_tokens: list[int] | None = None, + ) -> tuple[str, list[float], list[int]]: + """Generation with MTP (Multi-Token Prediction) speculative decoding.""" + if prompt_tokens is None: + prompt_tokens = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + enable_thinking=self.enable_thinking, + ) + + max_seq_len = self.config.max_seq_len + prompt_len = len(prompt_tokens) + total_len = min(max_seq_len, self.max_new_tokens + prompt_len) + + # Output tokens buffer + tokens = torch.full( + (self.batch_size, total_len), -1, dtype=torch.long, device=self.default_device + ) + tokens[0, :prompt_len] = torch.tensor( + prompt_tokens, dtype=torch.long, device=self.default_device + ) + + prefill_time_list = [] + decode_time_list = [] + decode_accepted_counts = [] # Only track decode phase for statistics + cur_pos = 0 # Current position in the output sequence + + # Prefill phase: process prompt tokens in non-overlapping chunks. + # Each chunk fills unique KV cache positions for both main model and MTP[0]. + while cur_pos < prompt_len - 1: + draft_end = min(cur_pos + self.mtp_seq_len, prompt_len) + draft_tokens = tokens[0, cur_pos:draft_end].clone() + actual_token_count = draft_tokens.shape[0] + + # Pad if needed (use last token for padding) + if actual_token_count < self.mtp_seq_len: + pad_token = draft_tokens[-1].item() + padding = torch.full( + (self.mtp_seq_len - actual_token_count,), + pad_token, + dtype=torch.long, + device=self.default_device, + ) + draft_tokens = torch.cat([draft_tokens, padding]) + + draft_tokens = draft_tokens.reshape(1, self.mtp_seq_len).to(torch.int32) + + # Provide the extra token for MTP[0]'s shifted input last position. + # MTP[0] needs tokens[cur_pos+1 : cur_pos+mtp_seq_len+1], so the + # extra token is at cur_pos + mtp_seq_len. + mtp_extra_pos = cur_pos + self.mtp_seq_len + if mtp_extra_pos < prompt_len: + mtp_extra_token = int(tokens[0, mtp_extra_pos].item()) + else: + # Beyond prompt — use last valid draft token as padding + mtp_extra_token = int(tokens[0, draft_end - 1].item()) + self.decode_layer.set_prefill_mtp_extra_token(mtp_extra_token) + + # Tell GPU how many tokens are valid (for cur_pos advancement) + self.decode_layer.set_prefill_valid_tokens(actual_token_count) + + start_time = time.time() + self.decode_layer.forward(draft_tokens, with_mtp=True) + end_time = time.time() + prefill_time_list.append(end_time - start_time) + + # No overlap: advance by the full actual_token_count + cur_pos += actual_token_count + + # After no-overlap prefill, cur_pos may have overshot to prompt_len. + # Reset to prompt_len - 1 for correct decode start (first decode + # reprocesses the last prompt token position). + cur_pos = prompt_len - 1 + self.set_cur_pos(prompt_len - 1) + + # Decode phase: speculative decoding + # Set prefill_valid_tokens to 0 to switch to decode mode + self.decode_layer.set_prefill_valid_tokens(0) + + finished = False + while cur_pos < total_len - 1 and not finished: + # Get next_draft_tokens from previous iteration + # (or use last prompt tokens for first decode) + if cur_pos == prompt_len - 1: + # First decode iteration: use last prompt token repeated as placeholder drafts + # We can't use [t6, t7, t8, t9] because that would apply wrong RoPE positions + # (cur_pos=9 means positions 9,10,11,12, but t6 should be at position 6) + last_token = tokens[0, prompt_len - 1].item() + draft_tokens = torch.full( + (self.mtp_seq_len,), + last_token, + dtype=torch.long, + device=self.default_device, + ) + draft_tokens = draft_tokens.reshape(1, self.mtp_seq_len).to(torch.int32) + else: + # Use next_draft_tokens from previous iteration + draft_tokens = self.decode_layer.get_next_draft_tokens(0).reshape( + 1, self.mtp_seq_len + ) + + start_time = time.time() + self.decode_layer.forward(draft_tokens, with_mtp=True) + end_time = time.time() + decode_time_list.append(end_time - start_time) + + num_accepted = self.decode_layer.get_num_accepted(0) + # Use predicted_tokens for output (not next_draft_tokens which is for next iteration) + predicted_tokens = self.decode_layer.get_predicted_tokens(0).flatten() + decode_accepted_counts.append(num_accepted) + + # Add accepted tokens to output + num_output_tokens = num_accepted + for i in range(num_output_tokens): + if cur_pos + 1 + i >= total_len: + break + new_token = int(predicted_tokens[i].item()) + tokens[0, cur_pos + 1 + i] = new_token + + # Print generated token + if cur_pos + 1 + i >= prompt_len and print_log: + decoded_text = self.tokenizer.decode([new_token], skip_special_tokens=True) + print(decoded_text, end="", flush=True) + + # Check for any stop token + if new_token in self.stop_token_ids: + finished = True + break + + cur_pos += num_accepted + + if print_log: + print("\n") + total_tokens = sum(decode_accepted_counts) + logger.info(f"--Number of forward calls (decode): {len(decode_accepted_counts)}") + logger.info(f"--Total tokens generated: {total_tokens}") + if len(decode_accepted_counts) > 0: + avg_accepted = sum(decode_accepted_counts) / len(decode_accepted_counts) + min_accepted = min(decode_accepted_counts) + max_accepted = max(decode_accepted_counts) + logger.info( + f"--Accepted tokens per call: mean={avg_accepted:.2f}, " + f"min={min_accepted}, max={max_accepted}" + ) + + # Calculate correct TPS accounting for MTP's multiple tokens per call + if decode_time_list: + total_decode_time = sum(decode_time_list) + effective_tps = total_tokens / total_decode_time if total_decode_time > 0 else 0 + avg_time_ms = total_decode_time / len(decode_time_list) * 1000 + logger.info( + f"--Avg forward time: {avg_time_ms:.2f}ms, " + + f"({1000 / avg_time_ms:.2f} forwards/s)" + ) + logger.info(f"--Effective TPS (with MTP): {effective_tps:.2f} tokens/s") + + print("\n") + + # Reset sequence after generation + self.decode_layer.reset_sequence() + + # Extract completion tokens + completion_tokens = [] + for _, toks in enumerate(tokens.tolist()): + toks = toks[prompt_len : prompt_len + self.max_new_tokens] + # Remove -1 padding and tokens after any stop token + toks = [t for t in toks if t != -1] + # Find first stop token and truncate + stop_idx = len(toks) + for i, tok in enumerate(toks): + if tok in self.stop_token_ids: + stop_idx = i + break + toks = toks[:stop_idx] + completion_tokens.append(toks) + + decoded_tokens = self.tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) + + return ( + f"{decoded_tokens[0]}\n" if decoded_tokens else "", + decode_time_list, + decode_accepted_counts, + ) + + def inject_cache( + self, + layer_caches: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], + start_pos: int = 0, + end_pos: int | None = None, + ) -> None: + """Inject external cache data into TileRT for P/D separation. + + This API allows injecting pre-computed KI/KV/PE cache data from an external + prefill system (e.g., SGLang), enabling prefill-decode disaggregation. + + Args: + layer_caches: List of (ki, kv, pe) tuples for each layer (0 to NUM_LAYERS-1). + Each tensor should be BF16 with shape [seqlen, dim] where: + - ki: [seqlen, 128] - compressed key (index_head_dim) + - kv: [seqlen, 512] - compressed key-value (kv_lora_rank) + - pe: [seqlen, 64] - position encoding cache (qk_rope_head_dim) + start_pos: Start position in cache to write (0-indexed). Defaults to 0. + end_pos: End position in cache (exclusive). If None, uses seqlen from tensors. + + Example: + >>> # Load cache from external prefill system + >>> layer_caches = [] # List of 78 (ki, kv, pe) tuples for GLM-5 + >>> for layer_id in range(78): + ... ki = load_ki_for_layer(layer_id) # [seqlen, 128] bf16 + ... kv = load_kv_for_layer(layer_id) # [seqlen, 512] bf16 + ... pe = load_pe_for_layer(layer_id) # [seqlen, 64] bf16 + ... layer_caches.append((ki, kv, pe)) + >>> generator.inject_cache(layer_caches, start_pos=0) + >>> generator.set_cur_pos(seqlen) # Set RoPE position + >>> # Continue generation from cache + """ + num_layers = len(layer_caches) + if num_layers == 0: + logger.warning("inject_cache called with empty layer_caches") + return + + # Infer seqlen from first tensor if end_pos not specified + first_ki, _, _ = layer_caches[0] + seqlen = first_ki.size(0) + if end_pos is None: + end_pos = start_pos + seqlen + + cache_len = end_pos - start_pos + logger.info(f"Injecting cache: {num_layers} layers, positions [{start_pos}, {end_pos})") + + num_devices = self.decode_layer.num_devices + + for device_id in range(num_devices): + _, caches, _, _ = self.decode_layer._get_device_result(device_id) + + for layer_id, (ki, kv, pe) in enumerate(layer_caches): + if layer_id >= num_layers: + logger.warning(f"Layer index {layer_id} is out of bounds, skipping.") + break + + # GLM-5 cache layout: 3 tensors per layer (ki, kv, pe) + # Based on CacheVarsGlm5: k_cache, kv_cache, pe_cache + base_idx = layer_id * 3 + + # Copy to device and inject into cache + # Cache layout: [batch=1, max_seq_len, dim] + # External data: [seqlen, dim] + ki_src = ki[:cache_len].to(f"cuda:{device_id}") + kv_src = kv[:cache_len].to(f"cuda:{device_id}") + pe_src = pe[:cache_len].to(f"cuda:{device_id}") + + caches[base_idx + 0][0, start_pos:end_pos, :].copy_(ki_src) + caches[base_idx + 1][0, start_pos:end_pos, :].copy_(kv_src) + caches[base_idx + 2][0, start_pos:end_pos, :].copy_(pe_src) + + logger.info(f"Cache injection completed for {num_devices} devices") + + def set_cur_pos(self, cur_pos: int) -> None: + """Set the current position for RoPE in C++ backend. + + This should be called after inject_cache() to ensure the C++ global + g_cur_pos matches the injected cache length. This is critical for + correct RoPE position encoding during continued generation. + + For MTP mode, sets the GPU tensor at intermediates[31] directly. + For non-MTP mode, calls the C++ dsa_show_hands_set_cur_pos_glm5 API. + + Args: + cur_pos: The current sequence position (typically the length of prefilled tokens). + + Example: + >>> generator.inject_cache(layer_caches, start_pos=0) + >>> generator.set_cur_pos(prefill_len) # Set position to prefill length + >>> # Now generate continues from the correct position + """ + if self.with_mtp: + # MTP E2E uses cur_pos tensor in TempVars + num_devices = self.decode_layer.num_devices + for device_id in range(num_devices): + intermediates, _, _, _ = self.decode_layer._get_device_result(device_id) + cur_pos_tensor = intermediates[Idx.CUR_POS] + cur_pos_tensor.fill_(cur_pos) + else: + # Non-MTP uses the C++ global g_cur_pos + torch.ops.tilert.dsa_show_hands_set_cur_pos_glm5(cur_pos) + logger.info(f"Set cur_pos to {cur_pos}") + + def inject_last_hidden_state(self, last_hidden_state: torch.Tensor) -> None: + """Inject the last hidden state for MTP mode. + + For MTP (Multi-Token Prediction), the MTP preprocess layer needs the + last hidden state from the main model's last token. This method injects + the hidden state into intermediates[33] (last_hidden_states slot). + + Args: + last_hidden_state: [hidden_size] or [1, hidden_size] BF16 tensor. + The hidden state of the last token from prefill. + + Example: + >>> # After inject_cache, inject the last hidden state for MTP + >>> generator.inject_last_hidden_state(last_hidden_state) + >>> generator.set_cur_pos(prefill_len) + >>> # Then start generation + """ + if not self.with_mtp: + logger.warning("inject_last_hidden_state called but with_mtp is False, skipping") + return + + # Normalize shape to [1, hidden_size] + if last_hidden_state.dim() == 1: + last_hidden_state = last_hidden_state.unsqueeze(0) + + num_devices = self.decode_layer.num_devices + for device_id in range(num_devices): + intermediates, _, _, _ = self.decode_layer._get_device_result(device_id) + # Shape: [batch=1, seq=4, hidden_size], we set seq[0] since it's the last token + lhs_tensor = intermediates[Idx.LAST_HIDDEN_STATES] + lhs_src = last_hidden_state.to(f"cuda:{device_id}") + lhs_tensor[0, 0, :].copy_(lhs_src.squeeze(0)) + + logger.info(f"Injected last_hidden_state to {num_devices} devices") diff --git a/python/models/glm_5/model_args.py b/python/models/glm_5/model_args.py new file mode 100644 index 0000000..a64ed6f --- /dev/null +++ b/python/models/glm_5/model_args.py @@ -0,0 +1,100 @@ +"""Model arguments and hyperparameters.""" + +from dataclasses import dataclass +from typing import Literal + +from tilert.models.deepseek_v3_2.model_args import ModelArgs + +__all__ = [ + "ModelArgsGLM5", +] + + +@dataclass +class ModelArgsGLM5(ModelArgs): + """ + Data class for defining model arguments and hyperparameters. + + Attributes: + arch_name (str): Architecture name. + max_batch_size (int): Maximum batch size. + max_seq_len (int): Maximum sequence length. + dtype (Literal["bf16", "fp8"]): Data type for computations. + scale_fmt (Optional[str]): Format for quantization scale. + vocab_size (int): Vocabulary size. + dim (int): Model dimension. + inter_dim (int): Intermediate dimension for MLP layers. + moe_inter_dim (int): Intermediate dimension for MoE layers. + n_layers (int): Number of transformer layers. + n_dense_layers (int): Number of dense layers in the model. + n_heads (int): Number of attention heads. + n_routed_experts (int): Number of routed experts for MoE layers. + n_shared_experts (int): Number of shared experts for MoE layers. + n_activated_experts (int): Number of activated experts in MoE layers. + n_expert_groups (int): Number of expert groups. + n_limited_groups (int): Number of limited groups for MoE routing. + score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing. + route_scale (float): Scaling factor for routing scores. + q_lora_rank (int): LoRA rank for query projections. + kv_lora_rank (int): LoRA rank for key-value projections. + qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. + qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. + v_head_dim (int): Dimension for value projections. + original_seq_len (Optional[int]): Original sequence length. + rope_theta (float): Base for rotary positional encoding. + rope_factor (Optional[float]): Scaling factor for extended sequence lengths. + beta_fast (Optional[int]): Fast beta correction factor. + beta_slow (Optional[int]): Slow beta correction factor. + mscale (float): Scaling factor for extended attention. + index_head_dim (int): Dimension for index head. + index_topk (int): Top-k for index head. + """ + + arch_name = "glm_5" + + max_batch_size: int = 1 # NOTE: the current implementation only supports a batch size being 1 + max_seq_len: int = 202752 + dtype: Literal["bf16", "fp8"] = "fp8" + scale_fmt: str | None = None + + vocab_size: int = 154880 + dim: int = 6144 + inter_dim: int = 12288 + moe_inter_dim: int = 2048 + n_layers: int = 78 + n_dense_layers: int = 3 + n_heads: int = 64 + + # moe + n_routed_experts: int = 256 + n_shared_experts: int = 1 + n_activated_experts: int = 8 + # n_expert_groups: int = 8 + # n_limited_groups: int = 4 + score_func: Literal["softmax", "sigmoid"] = "softmax" + route_scale: float = 2.5 + + # mla + q_lora_rank: int = 2048 + kv_lora_rank: int = 512 + qk_nope_head_dim: int = 192 + qk_rope_head_dim: int = 64 + v_head_dim: int = 256 + + # yarn + original_seq_len: int | None = None + rope_theta: float = 1000000.0 + rope_factor: float | None = None + beta_fast: int | None = None + beta_slow: int | None = None + mscale: float = 1.0 + + # index + index_n_heads: int = 32 + index_head_dim: int = 128 + index_topk: int = 2048 + + # quant + block_size: int = 128 + + eps: float = 1e-5 diff --git a/python/models/glm_5/params.py b/python/models/glm_5/params.py new file mode 100644 index 0000000..2721229 --- /dev/null +++ b/python/models/glm_5/params.py @@ -0,0 +1 @@ +"""GLM5 parameters and initializers.""" diff --git a/python/models/preprocess/__init__.py b/python/models/preprocess/__init__.py deleted file mode 100644 index 4bc42ac..0000000 --- a/python/models/preprocess/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Weight loading and preprocessing utilities.""" - -from tilert.models.preprocess.weight_utils import WeightLoader - -__all__ = [ - "WeightLoader", -] diff --git a/python/models/preprocess/weight_converter.py b/python/models/preprocess/weight_converter.py new file mode 100644 index 0000000..c0926aa --- /dev/null +++ b/python/models/preprocess/weight_converter.py @@ -0,0 +1,697 @@ +import json +import os +import pprint +from collections import OrderedDict +from typing import Any, TypedDict + +import torch +from safetensors.torch import load_file, save_file + +from tilert import logger +from tilert.models.deepseek_v3_2.model_args import ModelArgs +from tilert.models.deepseek_v3_2.model_args import ModelArgs as ModelArgsDsav32 +from tilert.models.deepseek_v3_2.modules.mla import Mla +from tilert.models.deepseek_v3_2.ops.down_allreduce import DownAllReduce +from tilert.models.deepseek_v3_2.ops.eh_proj_allreduce import EHProjAllReduce +from tilert.models.deepseek_v3_2.ops.expert_down_allreduce import ExpertDownAllReduce +from tilert.models.deepseek_v3_2.ops.expert_sel_up_gate_silu import ExpertSelectUpGateSiLU +from tilert.models.deepseek_v3_2.ops.rmsnorm_head_proj import RMSNormHeadProj +from tilert.models.deepseek_v3_2.ops.rmsnorm_up_gate_silu import RMSNormUpGateSiLU +from tilert.models.glm_5.model_args import ModelArgsGLM5 + +__all__ = [ + "WeightConverter", +] + + +class ShardInfo(TypedDict): + """Type definition for shard information.""" + + filename: str + tensors: list[str] + + +class WeightConverter: + """Weight converter for DeepSeek V3.2 model.""" + + def __init__( + self, + model_args: ModelArgs, + num_devices: int, + model_dir: str, + save_dir: str, + test_mode: bool = False, + ) -> None: + self.model_args = model_args + self.num_devices = num_devices + self.model_dir = model_dir + self.save_dir = save_dir + self.test_mode = test_mode + + self.num_dense_layers = model_args.n_dense_layers + self.num_moe_layers = model_args.n_layers - self.num_dense_layers + self.num_mtp_layers = 1 + self.total_layers = self.num_dense_layers + self.num_moe_layers + self.num_mtp_layers + if self.test_mode: + self.target_layers = [0, self.model_args.n_dense_layers, self.model_args.n_layers] + else: + self.target_layers = list(range(self.total_layers)) + + self.num_experts = model_args.n_routed_experts + + self.index_file = "model.safetensors.index.json" + self.__check_dir() + + # specially treated the embedding, norm, and head weights + # at the beginning and end of the model + self.emb_name = "model.embed_tokens.weight" + self.norm_name = "model.norm.weight" + self.head_name = "lm_head.weight" + self.special_treated_params: dict[str, str] = {} + + self.files_by_layers: dict[str, set[str]] = self.__group_by_layers() + self.default_device = "cpu" + + self.converted_weights_dict: dict[str, OrderedDict[str, torch.Tensor]] = {} + for i in range(self.num_devices): + self.converted_weights_dict[f"dev_{i}"] = OrderedDict() + + def __get_layer_num(self, param: str) -> int: + """Get layer number from parameter name.""" + if "layers" not in param: + return -1 + try: + return int(param.split(".")[2]) + except ValueError: + raise ValueError(f"Invalid file name: {param}") + + def __group_by_layers(self) -> dict[str, set[str]]: + """Load the index file.""" + with open(os.path.join(self.model_dir, self.index_file)) as f: + weight_map = json.load(f)["weight_map"] + + files_by_layers: dict[str, set[str]] = {} + for param, file_name in weight_map.items(): + layer_num = self.__get_layer_num(param) + if layer_num == -1: + logger.info(f"skip parameter {param} in {file_name}.") + self.special_treated_params[param] = file_name + continue + + key = f"layer_{layer_num}" + if key in files_by_layers: + files_by_layers[key].add(file_name) + else: + files_by_layers[key] = {file_name} + + return files_by_layers + + def __check_dir(self) -> None: + if not os.path.exists(self.model_dir): + raise ValueError(f"Model directory {self.model_dir} does not exist") + + if not os.path.exists(os.path.join(self.model_dir, self.index_file)): + raise ValueError(f"Index file {self.index_file} not found in {self.model_dir}") + + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) + + def get_tensor_size_bytes(self, tensor: torch.Tensor) -> int: + """Calculate the size of a tensor in bytes.""" + return int(tensor.numel() * tensor.element_size()) + + def parse_size(self, size_str: str) -> int: + """Parse size string like '1GB', '100MB' to bytes.""" + size_str = size_str.upper().strip() + if size_str.endswith("GB"): + return int(float(size_str[:-2]) * 1024 * 1024 * 1024) + if size_str.endswith("MB"): + return int(float(size_str[:-2]) * 1024 * 1024) + if size_str.endswith("KB"): + return int(float(size_str[:-2]) * 1024) + + return int(size_str) + + def save_file_sharded( + self, + weights_dict: dict[str, torch.Tensor], + base_filename: str, + max_shard_size: str = "4GB", + save_dir: str = "", + ) -> list[ShardInfo]: + """Save weights dictionary to multiple safetensors files. + + Each shard not exceeding max_shard_size. + + Args: + weights_dict: Dictionary of tensor names to tensors + base_filename: Base filename (e.g., "model.safetensors") + max_shard_size: Maximum size per shard (e.g., "1GB", "100MB") + save_dir: Directory to save the shards + """ + if save_dir: + base_filename = os.path.join(save_dir, base_filename) + + logger.info(f"Saving to safetensors format with max shard size {max_shard_size}...") + + max_size_bytes = self.parse_size(max_shard_size) + + tensor_nums = len(weights_dict) # placeholder for number sharded files + + shards: list[ShardInfo] = [] + current_shard: dict[str, torch.Tensor] = {} + current_size = 0 + shard_index = 1 # first shard is for embedding + + def get_shard_filename(shard_index: int) -> str: + return f"{base_filename}-{shard_index:05d}-of-{tensor_nums:05d}.safetensors" + + # Save embedding tensor to separate file + save_file(self.emb_weights_dict, get_shard_filename(shard_index)) + shards.append( + { + "filename": get_shard_filename(1), + "tensors": list(self.emb_weights_dict.keys()), + } + ) + + shard_index += 1 + for dev in weights_dict: + logger.info(f"Processing weights for device {dev}") + dev_tensors = weights_dict[dev] + + tensor_sizes = OrderedDict( + {name: self.get_tensor_size_bytes(tensor) for name, tensor in dev_tensors.items()} + ) + + # If adding this tensor would exceed max size, start a new shard + for tensor_name, tensor_size in tensor_sizes.items(): + if current_size + tensor_size > max_size_bytes and current_shard: + # Save current shard + shard_filename = get_shard_filename(shard_index) + logger.info(f"Saving shard {shard_index} to {shard_filename}") + save_file(current_shard, shard_filename) + + shards.append( + {"filename": shard_filename, "tensors": list(current_shard.keys())} + ) + current_shard = {} + current_size = 0 + shard_index += 1 + + # Add tensor to current shard + current_shard[tensor_name] = dev_tensors[tensor_name] + current_size += tensor_size + + # Save the last shard for the current device + if current_shard: + shard_filename = get_shard_filename(shard_index) + logger.info(f"Saving shard {shard_index} to {shard_filename}") + save_file(current_shard, shard_filename) + shards.append({"filename": shard_filename, "tensors": list(current_shard.keys())}) + current_shard = {} + current_size = 0 + shard_index += 1 + + # Update shard filenames with correct total count + total_shards = len(shards) + for i, shard in enumerate(shards, 1): + old_filename = shard["filename"] + new_filename = f"{base_filename}-{i:05d}-of-{total_shards:05d}.safetensors" + if old_filename != new_filename: + os.rename(old_filename, new_filename) + shard["filename"] = new_filename + + total_size = sum(self.get_tensor_size_bytes(t) for t in self.emb_weights_dict.values()) + for dev in weights_dict: + dev_tensors = weights_dict[dev] + tensor_sizes = OrderedDict( + {name: self.get_tensor_size_bytes(tensor) for name, tensor in dev_tensors.items()} + ) + total_size += sum(tensor_sizes.values()) + + index: dict[str, Any] = { + "metadata": {"total_size": total_size}, + "weight_map": {}, + } + + weight_map: dict[str, str] = index["weight_map"] # type: ignore[assignment] + for shard in shards: + for tensor_name in shard["tensors"]: + weight_map[tensor_name] = os.path.basename(shard["filename"]) + + index_filename = f"{base_filename}.index.json" + with open(index_filename, "w") as f: + json.dump(index, f, indent=2) + + logger.info(f"Saved {total_shards} shard(s) with max size {max_shard_size}") + logger.info(f"Index file: {index_filename}") + return shards + + def transform_mla( + self, + weights_hf: dict[str, torch.Tensor], + layer_id: int, + ) -> dict[str, dict[str, torch.Tensor]]: + mla_weights_map: dict[str, dict[str, torch.Tensor]] = {} + for dev_id in range(self.num_devices): + mla_weights_map.setdefault(f"dev_{dev_id}", {}) + mla = Mla(self.model_args, device_id=0, num_devices=self.num_devices) + mla_raw_dict = { + _k: weights_hf[f"model.layers.{layer_id}.{_k}"] for _k in mla.get_ref_weights_alias() + } + mla_sharded_dict = mla.device_sharding(mla_raw_dict) + for dev_id in range(self.num_devices): + for key, value in mla_sharded_dict.items(): + mla_weights_map[f"dev_{dev_id}"].update({key: value[dev_id].contiguous()}) + + mla_weights = {} + for dev_id in range(self.num_devices): + mla_weights_dev = {} + for key in mla_weights_map[f"dev_{dev_id}"].keys(): + mla_weights_dev.update({key: mla_weights_map[f"dev_{dev_id}"][key]}) + mla_weights.update({f"dev_{dev_id}": mla_weights_dev}) + + return mla_weights + + def transform_moe( + self, + weights_hf: dict[str, torch.Tensor], + layer_id: int, + ) -> dict[str, dict[str, torch.Tensor]]: + post_attn_norm_weight = f"model.layers.{layer_id}.post_attention_layernorm.weight" + mlp_gate_weight = f"model.layers.{layer_id}.mlp.gate.weight" + post_attn_norm_weight = weights_hf[post_attn_norm_weight].float() + mlp_gate_weight = weights_hf[mlp_gate_weight] + + moe_weights: dict[str, dict[str, torch.Tensor]] = {} + exp_sel_up_gate_silu = ExpertSelectUpGateSiLU(self.model_args, self.num_devices) + ref_scope = f"model.layers.{layer_id}." + exp_weights_map = { + k: weights_hf[ref_scope + k] for k in exp_sel_up_gate_silu.ref_weights_alias() + } + exp_sharded = exp_sel_up_gate_silu.device_sharding(exp_weights_map) + tilert_alias = exp_sel_up_gate_silu.tilert_weights_alias + exp_bias = exp_sharded[tilert_alias.exp_bias] + exp_gate_weights = exp_sharded[tilert_alias.exp_gate_weights] + exp_gate_scales = exp_sharded[tilert_alias.exp_gate_scales] + exp_up_weights = exp_sharded[tilert_alias.exp_up_weights] + exp_up_scales = exp_sharded[tilert_alias.exp_up_scales] + exp_down_allreduce = ExpertDownAllReduce( + self.model_args, device_id=0, num_devices=self.num_devices + ) + exp_down_weights, exp_down_scales = exp_down_allreduce.device_sharding( + weights_hf, f"model.layers.{layer_id}.mlp" + ) + for dev_id in range(self.num_devices): + key = f"dev_{dev_id}" + moe_weights.update( + { + key: { + "unproj_o_gamma": post_attn_norm_weight, + "exp_proj_weights": mlp_gate_weight, + "exp_bias": exp_bias[dev_id], + "exp_gate_weights": exp_gate_weights[dev_id], + "exp_gate_scales": exp_gate_scales[dev_id], + "exp_up_weights": exp_up_weights[dev_id], + "exp_up_scales": exp_up_scales[dev_id], + "exp_down_weights": exp_down_weights[dev_id], + "exp_down_scales": exp_down_scales[dev_id], + } + } + ) + return moe_weights + + def transform_mlp( + self, + weights_hf: dict[str, torch.Tensor], + layer_id: int, + ) -> dict[str, dict[str, torch.Tensor]]: + """Transform MLP weights.""" + print(RMSNormUpGateSiLU) + rmsnorm_up_gate_silu = RMSNormUpGateSiLU( + self.model_args, device_id=0, num_devices=self.num_devices + ) + post_attn_norm_weight, gate_weights, gate_scales, up_weights, up_scales = ( + rmsnorm_up_gate_silu.device_sharding(weights_hf, f"model.layers.{layer_id}.mlp") + ) + down_allreduce = DownAllReduce(self.model_args, device_id=0, num_devices=self.num_devices) + down_weights, down_scales = down_allreduce.device_sharding( + weights_hf, f"model.layers.{layer_id}.mlp" + ) + + weights_unproj_o_gamma: dict[str, dict[str, torch.Tensor]] = {} + for dev_id in range(self.num_devices): + weights_unproj_o_gamma[f"dev_{dev_id}"] = { + "unproj_o_gamma": post_attn_norm_weight[dev_id] + } + + weights_upgate: dict[str, dict[str, torch.Tensor]] = {} + for dev_id in range(self.num_devices): + weights_upgate.update( + { + f"dev_{dev_id}": { + "gate_weights": gate_weights[dev_id], + "gate_scales": gate_scales[dev_id], + "up_weights": up_weights[dev_id], + "up_scales": up_scales[dev_id], + } + } + ) + + weights_down: dict[str, dict[str, torch.Tensor]] = {} + for dev_id in range(self.num_devices): + weights_down.update( + { + f"dev_{dev_id}": { + "down_weights": down_weights[dev_id], + "down_scales": down_scales[dev_id], + } + } + ) + + mlp_weights: dict[str, dict[str, torch.Tensor]] = {} + for dev_id in range(self.num_devices): + mlp_weights_dev: dict[str, torch.Tensor] = {} + mlp_weights_dev.update(weights_unproj_o_gamma[f"dev_{dev_id}"]) + mlp_weights_dev.update(weights_upgate[f"dev_{dev_id}"]) + mlp_weights_dev.update(weights_down[f"dev_{dev_id}"]) + mlp_weights[f"dev_{dev_id}"] = mlp_weights_dev + return mlp_weights + + def transform_mtp( + self, + weights_hf: dict[str, torch.Tensor], + layer_id: int, + ) -> dict[str, dict[str, torch.Tensor]]: + """Transform MTP weights. + + Transformations applied: + - enorm.weight: Direct use (fp32) + - hnorm.weight: Direct use (fp32) + - eh_proj.weight: Split along dim 1, reshape [7168, 1792] -> [128, 7, 56, 256] + """ + enorm_weight_key = f"model.layers.{layer_id}.enorm.weight" + hnorm_weight_key = f"model.layers.{layer_id}.hnorm.weight" + enorm_weight = weights_hf[enorm_weight_key] + hnorm_weight = weights_hf[hnorm_weight_key] + + eh_proj_allreduce = EHProjAllReduce(self.model_args, self.num_devices) + (eh_proj_weights,) = eh_proj_allreduce.device_sharding( + weights_hf, f"model.layers.{layer_id}" + ) + + return { + f"dev_{dev_id}": { + "embedding_rmsnorm_gamma": enorm_weight, + "hidden_rmsnorm_gamma": hnorm_weight, + "eh_proj_weights": eh_proj_weights[dev_id], + } + for dev_id in range(self.num_devices) + } + + def convert_a_layer(self, layer_idx: int) -> tuple[ + dict[str, dict[str, torch.Tensor]], + dict[str, dict[str, torch.Tensor]], + dict[str, dict[str, torch.Tensor]], + ]: + assert layer_idx < self.total_layers + + key = f"layer_{layer_idx}" + files_to_load = self.files_by_layers[key] + + weights_dict = {} + for file_name in files_to_load: + logger.info(f"Loading weight from {file_name}") + path = os.path.join(self.model_dir, file_name) + weights = load_file(path, device=self.default_device) + weights_dict.update(weights) + + mla_weights = self.transform_mla(weights_dict, layer_idx) + + if layer_idx < self.num_dense_layers: + mlp_weights = self.transform_mlp(weights_dict, layer_idx) + else: + mlp_weights = self.transform_moe(weights_dict, layer_idx) + + mtp_weights: dict[str, dict[str, torch.Tensor]] = { + f"dev_{dev_id}": {} for dev_id in range(self.num_devices) + } + if layer_idx >= self.num_dense_layers + self.num_moe_layers: + mtp_weights = self.transform_mtp(weights_dict, layer_idx) + + return mla_weights, mlp_weights, mtp_weights + + def __process_head_weights(self) -> None: + """Process head weights.""" + head_weight_file = self.special_treated_params[self.head_name] + head_weight_file = os.path.join(self.model_dir, head_weight_file) + head_weights = load_file(head_weight_file, device=self.default_device)[self.head_name] + + norm_weight_file = self.special_treated_params[self.norm_name] + norm_weight_file = os.path.join(self.model_dir, norm_weight_file) + norm_weights = load_file(norm_weight_file, device=self.default_device)[self.norm_name] + + weights_hf = { + "model.norm.weight": norm_weights, + "lm_head.weight": head_weights, + } + + layer_idx = self.num_dense_layers + self.num_moe_layers + rmsnorm_head_proj = RMSNormHeadProj( + self.model_args, device_id=0, num_devices=self.num_devices + ) + gamma, head_proj = rmsnorm_head_proj.device_sharding(weights_hf) + + for dev_id in range(self.num_devices): + self.converted_weights_dict[f"dev_{dev_id}"][ + f"layer_{layer_idx}_lm_head.weight_dev_{dev_id}" + ] = head_proj[dev_id] + self.converted_weights_dict[f"dev_{dev_id}"][ + f"layer_{layer_idx}_model.norm.weight_dev_{dev_id}" + ] = gamma[dev_id] + + def __process_embedding_weights(self) -> None: + """Process embedding weights.""" + embedding_weight_file = self.special_treated_params[self.emb_name] + embedding_weight_file = os.path.join(self.model_dir, embedding_weight_file) + embedding_weights = load_file(embedding_weight_file, device=self.default_device)[ + self.emb_name + ] + self.emb_weights_dict = {"model.embed_tokens.weight": embedding_weights} + + def __post_process_weights( + self, + mla_weights: dict[str, dict[str, torch.Tensor]], + mlp_weights: dict[str, dict[str, torch.Tensor]], + mtp_weights: dict[str, dict[str, torch.Tensor]], + layer_idx: int, + ) -> None: + """Post process weights.""" + for weights_group in [mla_weights, mlp_weights, mtp_weights]: + for dev, params in weights_group.items(): + for param_name, tensor in params.items(): + new_key = f"layer_{layer_idx}_{param_name}_{dev}" + self.converted_weights_dict[dev][new_key] = tensor + + def to_tilert_weights(self) -> None: + torch.set_default_device(self.default_device) + + for i in range(self.total_layers): + if i not in self.target_layers: + logger.info(f"Skipping layer {i + 1} / {self.total_layers}") + continue + logger.info(f"Converting weight layer {i + 1} / {self.total_layers}") + + mla_weights, mlp_weights, mtp_weights = self.convert_a_layer(i) + self.__post_process_weights(mla_weights, mlp_weights, mtp_weights, i) + + self.__process_head_weights() + self.__process_embedding_weights() + + def _get_layer_num(file_name: str) -> tuple[int, int]: + """Extract layer number from filename like 'layer_XX.xxx'.""" + if "/" in file_name: + file_name = file_name.split("/")[-1] + + parts = file_name.split("_") + try: + layer_num = int(parts[1]) + except ValueError: + raise ValueError(f"Could not find layer number in parameter name: {file_name}") + try: + device_id = int(parts[-1]) + except ValueError: + raise ValueError(f"Could not find device id in parameter name: {file_name}") + return (device_id, layer_num) + + def _sort_key(filename: str) -> tuple[int, int]: + """Sort key function that returns (layer_num, device_id).""" + try: + return _get_layer_num(filename) + except ValueError: + return (999999, 999999) # If layer number not found, put at the end + + tilert_weights = sorted( + self.converted_weights_dict, key=lambda x: _sort_key(x), reverse=False + ) + pprint.pprint(tilert_weights) # noqa: T203 + + self.save_file_sharded( + self.converted_weights_dict, + "model.safetensors", + max_shard_size="5GB", + save_dir=self.save_dir, + ) + + def append_mtp_weights_to_safetensors( + self, + existing_save_dir: str, + max_shard_size: str = "5GB", + ) -> None: + """Append MTP layer weights to existing safetensors files. + + This method is used when layer 0-60 weights have already been converted, + and we only need to add the MTP layer (layer 61) weights. + + Note: lm_head.weight and model.norm.weight are already included in the + existing safetensors (converted with layer 0-60), so we only append: + - MTP preprocess weights (enorm, hnorm, eh_proj) + - MTP MLA weights + - MTP MoE weights + + Args: + existing_save_dir: Directory containing existing converted weights + max_shard_size: Maximum shard size for new safetensors files + """ + torch.set_default_device(self.default_device) + + # Load existing index.json + existing_index_file = os.path.join(existing_save_dir, "model.safetensors.index.json") + if not os.path.exists(existing_index_file): + raise ValueError(f"Existing index file not found: {existing_index_file}") + + with open(existing_index_file) as f: + existing_index = json.load(f) + + existing_weight_map: dict[str, str] = existing_index["weight_map"] + existing_total_size: int = existing_index["metadata"]["total_size"] + + # Find the next shard number + existing_shards = set(existing_weight_map.values()) + max_shard_num = 0 + for shard_name in existing_shards: + # Parse shard number from filename like "model.safetensors-00001-of-00010.safetensors" + parts = shard_name.replace(".safetensors", "").split("-") + if len(parts) >= 2: + try: + shard_num = int(parts[-2]) + max_shard_num = max(max_shard_num, shard_num) + except ValueError: + pass + + logger.info( + f"Found {len(existing_shards)} existing shards, max shard number: {max_shard_num}" + ) + + # Convert MTP layer (layer 61) weights + mtp_layer_idx = self.num_dense_layers + self.num_moe_layers # 61 + logger.info(f"Converting MTP layer {mtp_layer_idx} weights...") + + mla_weights, mlp_weights, mtp_weights = self.convert_a_layer(mtp_layer_idx) + + # Collect MTP layer weights for all devices + # Clone tensors to avoid shared memory issues when saving to safetensors + mtp_layer_weights: dict[str, torch.Tensor] = {} + for weights_group in [mla_weights, mlp_weights, mtp_weights]: + for dev, params in weights_group.items(): + for param_name, tensor in params.items(): + new_key = f"layer_{mtp_layer_idx}_{param_name}_{dev}" + mtp_layer_weights[new_key] = tensor.clone() + + logger.info(f"Collected {len(mtp_layer_weights)} MTP layer weight tensors") + + # Calculate size of new weights + new_weights_size = sum(self.get_tensor_size_bytes(t) for t in mtp_layer_weights.values()) + + # Save MTP weights to new shard file(s) + # Use a separate naming scheme to avoid modifying existing shards + max_size_bytes = self.parse_size(max_shard_size) + new_shards: list[ShardInfo] = [] + current_shard: dict[str, torch.Tensor] = {} + current_size = 0 + mtp_shard_index = 1 # Start from 1 for MTP shards + + for tensor_name, tensor in mtp_layer_weights.items(): + tensor_size = self.get_tensor_size_bytes(tensor) + if current_size + tensor_size > max_size_bytes and current_shard: + # Save current shard with MTP-specific naming + shard_filename = f"model_mtp_layer61-{mtp_shard_index:05d}.safetensors" + shard_path = os.path.join(existing_save_dir, shard_filename) + logger.info(f"Saving MTP shard to {shard_filename}") + save_file(current_shard, shard_path) + new_shards.append( + {"filename": shard_filename, "tensors": list(current_shard.keys())} + ) + current_shard = {} + current_size = 0 + mtp_shard_index += 1 + + current_shard[tensor_name] = tensor + current_size += tensor_size + + # Save the last shard + if current_shard: + shard_filename = f"model_mtp_layer61-{mtp_shard_index:05d}.safetensors" + shard_path = os.path.join(existing_save_dir, shard_filename) + logger.info(f"Saving MTP shard to {shard_filename}") + save_file(current_shard, shard_path) + new_shards.append({"filename": shard_filename, "tensors": list(current_shard.keys())}) + + # Update weight_map with new MTP weights (existing shards remain unchanged) + for shard in new_shards: + for tensor_name in shard["tensors"]: + existing_weight_map[tensor_name] = shard["filename"] + + # Update index.json + updated_index = { + "metadata": {"total_size": existing_total_size + new_weights_size}, + "weight_map": existing_weight_map, + } + + with open(existing_index_file, "w") as f: + json.dump(updated_index, f, indent=2) + + logger.info(f"Added {len(new_shards)} new MTP shard(s)") + logger.info(f"New total size: {existing_total_size + new_weights_size}") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model_type", type=str, required=True) + parser.add_argument("--model_dir", type=str, required=True) + parser.add_argument("--save_dir", type=str, required=True) + parser.add_argument("--test_mode", action="store_true", help="Test mode") + parser.add_argument( + "--append_mtp", + action="store_true", + help="Append MTP layer (layer 61) weights to existing safetensors. " + "Use this when layer 0-60 weights have already been converted.", + ) + args = parser.parse_args() + + model_type = args.model_type + if model_type == "deepseek-v32": + model_args = ModelArgsDsav32() + elif model_type == "glm-5": + model_args = ModelArgsGLM5() + else: + raise ValueError(f"Invalid model type: {model_type}") + + converter = WeightConverter(model_args, 8, args.model_dir, args.save_dir, args.test_mode) + if args.append_mtp: + converter.append_mtp_weights_to_safetensors(args.save_dir) + else: + converter.to_tilert_weights() diff --git a/python/models/preprocess/weight_utils.py b/python/models/preprocess/weight_utils.py deleted file mode 100644 index 5330e14..0000000 --- a/python/models/preprocess/weight_utils.py +++ /dev/null @@ -1,568 +0,0 @@ -"""Weight loading and preprocessing utilities.""" - -import os -from typing import Any - -import torch - -from tilert import logger -from tilert.models.deepseek_config import get_world_size - -__all__ = [ - "print_weights_info", - "WeightLoader", - "DownAllreduceWeightsConverter", - "RMSNormProjQAKVAKIWeightsConverter", - "RMSNormHeadProjWeightsConverter", - "ExpertSelectUpGateSiLUWeightsConverter", - "RMSNormUpGateSiLUWeightsConverter", -] - - -def print_weights_info(weights_path: str) -> None: - """Print the information of the weights.""" - try: - weights = torch.load( - weights_path, - map_location="cuda", - weights_only=True, - ) - print("Successfully loaded weights. Available keys:") - for key in weights.keys(): - print(f" - {key}, shape: {weights[key].shape}") - except Exception as e: - print(f"Error loading weights: {e}") - raise - - -class WeightLoader: - """Weight loader for TileRT models.""" - - def __init__( - self, - layer_idx: int = 0, - golden_weights_dir: str = "", - tilert_weights_dir: str = "", - ) -> None: - """Initialize the weight loader. - - Args: - layer_idx: Layer index. - golden_weights_dir: Path to golden weights directory. - tilert_weights_dir: Path to tilert weights directory. - """ - self.layer_idx = layer_idx - self.golden_weights_dir = golden_weights_dir - self.tilert_weights_dir = tilert_weights_dir - - self.weights_loaded_golden = False - self.weights_dict_golden: dict[str, dict[str, Any]] = {} - - self.weights_loaded_tilert = False - self.weights_dict_tilert: dict[str, dict[str, Any]] = {} - - def get_weight_file_path(self, device_id: int = 0, is_tilert: bool = False) -> str: - """Get the weight file path for a given layer. - - Args: - device_id: Device id. - is_tilert: Whether the weights are for tilert. - """ - if is_tilert: - return os.path.join( - self.tilert_weights_dir, - f"tilert_deepseek_v32.layer_{self.layer_idx}.dev_{device_id}.weights.pt", - ) - - return os.path.join( - self.golden_weights_dir, - f"deepseek_v32.layer_{self.layer_idx}.weights.pt", - ) - - def get_weight_prefix(self) -> str: - """Get the weight file prefix for a given layer.""" - return f"model.layers.{self.layer_idx}." - - def register_weights( - self, weights_config: dict[str, dict[str, Any]], is_tilert: bool = False - ) -> None: - """Register weights configuration. - - Args: - weights_config: Dictionary mapping weight names to their configurations. - Each configuration should have 'shape', 'dtype', and 'data' keys. - is_tilert: Whether the weights are for tilert. - """ - if is_tilert: - self.weights_dict_tilert.update(weights_config) - else: - self.weights_dict_golden.update(weights_config) - - def check_shape( - self, data_shape: torch.Size, config_shape: tuple[int, ...], split_method: str = "no_split" - ) -> None: - """Check if the shape of the data is the same as the shape in the weights configuration. - - Args: - data_shape: Shape of the data tensor. - config_shape: Expected shape from the configuration. - - Raises: - ValueError: If the shapes don't match. - """ - data_shape = tuple(data_shape) - config_shape = tuple(config_shape) - - if split_method == "row_split": - new_shape = (data_shape[0], data_shape[1] // get_world_size()) - elif split_method == "column_split": - new_shape = (data_shape[0] // get_world_size(), data_shape[1]) - elif split_method == "no_split": - new_shape = data_shape - else: - raise ValueError(f"Invalid split method: {split_method}") - - if new_shape != config_shape: - raise ValueError(f"Shape mismatch: got {new_shape}, expected {config_shape}") - - def load_weights(self, weights_path: str, device_id: int = 0) -> None: - """Load weights from the weights path. - - Args: - weights_path: Path to weights file. - device_id: Device id. - """ - if not os.path.exists(weights_path): - raise ValueError(f"Weights path {weights_path} does not exist") - - # TODO(ying): Enhance the error handling for weights loading. - device = torch.device(f"cuda:{device_id}") - weights = torch.load( - weights_path, - map_location=device, - weights_only=True, - ) - - for key in self.weights_dict_golden: - weight_name = self.get_weight_prefix() + key - if weight_name not in weights: - raise ValueError(f"Weight {weight_name} not found in weights file") - - data = weights[weight_name] - logger.info(f"Loaded weight {weight_name} with shape {data.shape}") - - item = self.weights_dict_golden[key] - split_method = item.get("split_method", "no_split") - self.check_shape(data.shape, item["shape"], split_method) - - if split_method == "row_split": - split_size = data.shape[1] // get_world_size() - start_idx = device_id * split_size - end_idx = start_idx + split_size - data = data[:, start_idx:end_idx] - elif split_method == "column_split": - split_size = data.shape[0] // get_world_size() - start_idx = device_id * split_size - end_idx = start_idx + split_size - data = data[start_idx:end_idx, :] - elif split_method == "no_split": - pass - else: - raise ValueError(f"Invalid split method: {split_method}") - - if isinstance(item["data"], torch.Tensor): - item["data"].copy_(data) - else: - item["data"] = data - - self.weights_loaded_golden = True - - def load_tilert_weights(self, weights_path: str, device_id: int = 0) -> None: - """Load tilert weights from the weights path. - - Args: - weights_path: Path to weights file. - device_id: Device id. - """ - if not os.path.exists(weights_path): - raise ValueError(f"Weights path {weights_path} does not exist") - - device = torch.device(f"cuda:{device_id}") - weights = torch.load( - weights_path, - map_location=device, - weights_only=True, - ) - - for key in self.weights_dict_tilert: - if key not in weights: - raise ValueError(f"Weight {key} not found in weights file") - - data = weights[key] - logger.info(f"Loaded weight {key} with shape {data.shape}") - if isinstance(self.weights_dict_tilert[key]["data"], torch.Tensor): - self.check_shape(tuple(data.shape), self.weights_dict_tilert[key]["shape"]) - self.weights_dict_tilert[key]["data"].copy_(data) - else: - self.weights_dict_tilert[key]["data"] = data - - self.weights_loaded_tilert = True - - def get_weight(self, name: str, from_tilert: bool = False) -> Any: - """Get a weight by name. - - Args: - name: Weight name. - - Returns: - Weight data. - - Raises: - ValueError: If weight is not found or not loaded. - """ - weight_dict = self.weights_dict_tilert if from_tilert else self.weights_dict_golden - - if name not in weight_dict: - raise ValueError(f"Weight {name} not registered") - - if from_tilert: - if not self.weights_loaded_tilert: - raise ValueError("Tilert weights not loaded. Call load_tilert_weights first.") - elif not self.weights_loaded_golden: - raise ValueError("Golden weights not loaded. Call load_weights first.") - - return weight_dict[name]["data"] - - -class RMSNormProjQAKVAKIWeightsConverter: - """Weights converter class.""" - - @staticmethod - def tilert_to_common( - tilert_wqkv_a: torch.Tensor, - tilert_wqkv_a_scales: torch.Tensor, - tilert_attn_norm_weight: torch.Tensor, - ) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - ]: - """ - Convert tilert weights to common weights. - - Args: - tilert_wqkv_a: Tilert weight tensor. - tilert_wqkv_a_scales: Tilert weight scale tensor. - tilert_attn_norm_weight: Tilert attention norm weight tensor. - Returns: - tuple: Common weights. - """ - wq_a = tilert_wqkv_a[:1536] # 1536, 7168 - wkv_a = tilert_wqkv_a[1536 : 1536 + 576] # 576, 7168 - wk = tilert_wqkv_a[1536 + 576 :] # 128, 7168 - - wqkv_a_scales_0 = tilert_wqkv_a_scales[:128, :].reshape(16, 8, 64) - wqkv_a_scales_0 = wqkv_a_scales_0[:, 0, :].reshape(16, 64) - wqkv_a_scales_1 = tilert_wqkv_a_scales[128:129, :] # 1, 64 - wqkv_a_scales_2 = tilert_wqkv_a_scales[129:, :] # 1, 64 - wqkv_a_scales_swizzled = torch.cat( - [wqkv_a_scales_0, wqkv_a_scales_1, wqkv_a_scales_2], dim=0 - ) - wqkv_scales = torch.zeros( - (18, 56), dtype=torch.bfloat16, device=tilert_wqkv_a_scales.device - ) - - for i in range(64): - if ((i % 8) * 8 + i // 8) < 56: - wqkv_scales[:, ((i % 8) * 8 + i // 8)] = wqkv_a_scales_swizzled[:, i] - wq_a_scale = wqkv_scales[:12, :] # 12, 56 - wkv_a_scale = wqkv_scales[12:17, :] # 5, 56 - wk_scale = wqkv_scales[17:, :] # 1, 56 - - attn_norm_weight = tilert_attn_norm_weight - return wq_a, wq_a_scale, wkv_a, wkv_a_scale, wk, wk_scale, attn_norm_weight - - @staticmethod - def common_to_tilert_native_bf16_warp_gemv( - wq_a: torch.Tensor, - wq_a_scale: torch.Tensor, - wkv_a: torch.Tensor, - wkv_a_scale: torch.Tensor, - wk: torch.Tensor, - wk_scale: torch.Tensor, - attn_norm_weight: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Convert common weights to weights for tilert native bf16 warp gemv op. - - Args: - wq_a: Common weight tensor. - wq_a_scale: Common weight scale tensor. - wkv_a: Common weight tensor. - wkv_a_scale: Common weight scale tensor. - wk: Common weight tensor. - wk_scale: Common weight scale tensor. - attn_norm_weight: Common attention norm weight tensor. - Returns: - tuple: Tilert weights for native bf16 warp gemv op. - """ - wq_a_scale = wq_a_scale.reshape((12, 56, 1)).repeat(1, 1, 128).reshape((12, 1, 7168)) - wq_a_scale = wq_a_scale.repeat(1, 128, 1).reshape((1536, 7168)) - wkv_a_scale = wkv_a_scale.reshape((5, 56, 1)).repeat(1, 1, 128).reshape((5, 1, 7168)) - wkv_a_scale = wkv_a_scale.repeat(1, 128, 1).reshape((-1, 7168)) - wkv_a_scale = wkv_a_scale[:576] - wk_scale = wk_scale.reshape((1, 56, 1)).repeat(1, 1, 128).reshape((1, 1, 7168)) - wk_scale = wk_scale.repeat(1, 128, 1).reshape((128, 7168)) - wq_a = wq_a.reshape((1536, 7168)).float() * wq_a_scale.float() - wkv_a = wkv_a.reshape((576, 7168)).float() * wkv_a_scale.float() - wk = wk.reshape((128, 7168)).float() * wk_scale.float() - # concatenate the weights - weights = torch.cat([wq_a, wkv_a, wk], dim=0) - assert weights.shape == (1536 + 576 + 128, 7168) - - weights = weights.reshape(140, 16, 7, 1024) - weights = weights.transpose(1, 2) # 140, 7, 16, 1024 - return weights.to(torch.bfloat16).contiguous(), attn_norm_weight.clone() - - -class ExpertSelectUpGateSiLUWeightsConverter: - """Weights converter class.""" - - @staticmethod - def _swizzle_mma_16x32(mat_in: torch.Tensor) -> torch.Tensor: - assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32 - # PTX isa fig.88 - pre_shape = mat_in.shape[:-2] - mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4) - return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2) - - @staticmethod - def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor: - assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16 - # PTX isa fig.88 - pre_shape = mat_in.shape[:-2] - mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4) - return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2) - - @staticmethod - def tilert_to_tilert_144sm( - mat_in: torch.Tensor, mat_scale_in: torch.Tensor, mma_type: str | None = None - ) -> torch.Tensor: - """ - Convert tilert weights and scales to tilert_144sm input format. - - Args: - mat_in: tilert weights - mat_scale_in: tilert scales - mma_type: MMA type, None,"16x32" or "16x16" - Returns: - tilert_144sm weights and scales - """ - exp_num = mat_in.shape[0] - assert mat_in.shape == (exp_num, 512, 7168) - assert mat_scale_in.shape == (exp_num, 4, 64) - weights_trt = mat_in.reshape(exp_num, 128, 4, 7168) - weights_w1 = weights_trt[:, :, :2].reshape(exp_num, 256, 7168) - weights_w3 = weights_trt[:, :, 2:].reshape(exp_num, 256, 7168) - # to 16x1024 blocks - weights_w1 = weights_w1.reshape(exp_num, 16, 16, 7, 1024).transpose(2, 3) - weights_w3 = weights_w3.reshape(exp_num, 16, 16, 7, 1024).transpose(2, 3) - if mma_type == "16x32": - weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 32, 32).transpose(3, 4) - weights_w1 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x32(weights_w1) - weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 1024) - weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 32, 32).transpose(3, 4) - weights_w3 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x32(weights_w3) - weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 1024) - elif mma_type == "16x16": - weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 64, 16).transpose(3, 4) - weights_w1 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x16(weights_w1) - weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 1024) - weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 64, 16).transpose(3, 4) - weights_w3 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x16(weights_w3) - weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 1024) - - weights = torch.cat([weights_w1, weights_w3], dim=3) - assert weights.shape == (exp_num, 16, 7, 32, 1024) - weights = weights.reshape(exp_num, 16, 7, 32 * 1024) - - # For scales, first unswizzle - scales_unswizzled = torch.zeros(exp_num, 4, 56) - for i in range(64): - if ((i % 8) * 8 + i // 8) < 56: - scales_unswizzled[..., ((i % 8) * 8 + i // 8)] = mat_scale_in[..., i] - scales_unswizzled = scales_unswizzled.reshape(exp_num, 2, 2, 56) - - scales_w1 = scales_unswizzled[:, :, :1].repeat(1, 1, 8, 1).reshape(exp_num, 16, 1, 7, 8) - scales_w1 = scales_w1.transpose(2, 3) - scales_w3 = scales_unswizzled[:, :, 1:].repeat(1, 1, 8, 1).reshape(exp_num, 16, 1, 7, 8) - scales_w3 = scales_w3.transpose(2, 3) - scales = torch.cat([scales_w1, scales_w3], dim=3) - assert scales.shape == (exp_num, 16, 7, 2, 8) - scales = ( - scales.reshape(exp_num, 16, 7, 2 * 8).to(torch.bfloat16).view(dtype=torch.float8_e4m3fn) - ) - weights_and_scales = torch.zeros( - exp_num, 16, 7, 32 * 1024 + 128, dtype=torch.float8_e4m3fn, device=mat_in.device - ) - weights_and_scales[:, :, :, : 32 * 1024].copy_(weights) - weights_and_scales[:, :, :, 32 * 1024 : 32 * 1024 + 32].copy_(scales) - return weights_and_scales - - @staticmethod - def tilert_to_tilert_144sm_mma( - mat_in: torch.Tensor, mat_scale_in: torch.Tensor, mma_type: str = "16x32" - ) -> torch.Tensor: - """ - Convert tilert weights and scales to tilert_144sm_mma input format. - - Args: - mat_in: tilert weights - mat_scale_in: tilert scales - Returns: - tilert_144sm weights and scales - """ - return ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm( - mat_in, mat_scale_in, mma_type - ) - - -class RMSNormHeadProjWeightsConverter: - """Weights converter class.""" - - @staticmethod - def tilert_to_tilert_native_bf16_warp_gemv( - tilert_weight_in: torch.Tensor, - ) -> torch.Tensor: - """Convert TILERT weights to TILERT native bf16 warp gemv weights.""" - weights = tilert_weight_in.reshape(1010, 16, 7, 1024) - weights = weights.transpose(1, 2).reshape(7070, 16, 1024) - return weights.contiguous() - - -class RMSNormUpGateSiLUWeightsConverter: - """Weights converter class.""" - - @staticmethod - def tilert_to_tilert_144sm( - mat_in: torch.Tensor, - mat_scale_in: torch.Tensor, - mma_type: str | None = None, - ) -> torch.Tensor: - """ - Convert tilert weights and scales to tilert_144sm input format. - - Args: - mat_in: tilert weights - mat_scale_in: tilert scales - Returns: - tilert_144sm weights and scales - """ - return ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm( - mat_in, mat_scale_in, mma_type - ) - - @staticmethod - def tilert_to_tilert_144sm_mma( - mat_in: torch.Tensor, - mat_scale_in: torch.Tensor, - mma_type: str = "16x32", - ) -> torch.Tensor: - """Convert tilert weights and scales to tilert_144sm_mma input format.""" - return ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm_mma( - mat_in, mat_scale_in, mma_type - ) - - -class UnProjOAllreduceWeightsConverter: - """Weights converter class.""" - - @staticmethod - def tilert_to_tilert_112sm_mma( - mat_in: torch.Tensor, - mat_scale_in: torch.Tensor, - ) -> torch.Tensor: - """ - Convert tilert weights to tilert_112sm_mma input format. - - Args: - mat_in: tilert weights, [7168, 2048] - mat_scale_in: tilert scales, [896, 16] - Returns: - tilert_112sm weights, scales - """ - swizzle_for_mma_16x32 = True - assert mat_in.shape == (7168, 2048) - assert mat_scale_in.shape == (896, 16) - weights_trt = mat_in.reshape(112, 64, 2048) - # to 64*512 blocks - weights_trt = weights_trt.reshape(112, 64, 4, 512).transpose(1, 2) # (112, 4, 64, 512) - if swizzle_for_mma_16x32: - # to (112, 4, 4(n), 16(k), 16, 32) - weights_trt = weights_trt.reshape(112, 4, 4, 16, 16, 32).transpose(-2, -3) - weights_trt = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x32(weights_trt) - weights_trt = weights_trt.reshape(112, 4, 4 * 16 * 16 * 32) - - # For scales - scales_trt = mat_scale_in.reshape(56, 16, 16)[:, 0, :] - scales_trt = scales_trt.reshape(56, 16) - return weights_trt.contiguous(), scales_trt.contiguous() - - -class DownAllreduceWeightsConverter: - """Weights converter class.""" - - @staticmethod - def _swizzle_mma_16x32(mat_in: torch.Tensor) -> torch.Tensor: - assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32 - # PTX isa fig.88 - pre_shape = mat_in.shape[:-2] - mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4) - return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2) - - @staticmethod - def _swizzle_mma_8x32(mat_in: torch.Tensor) -> torch.Tensor: - assert mat_in.shape[-2] == 8 and mat_in.shape[-1] == 32 - # PTX isa fig.88 - pre_shape = mat_in.shape[:-2] - return mat_in.reshape(*pre_shape, 8, 2, 4, 4).transpose(-2, -3).contiguous() - - @staticmethod - def tilert_to_tilert_mma( - mat_in: torch.Tensor, scale_in: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Convert tilert weights and scales to tilert_mma input format. - - Args: - mat_in: tilert weights - mat_scale_in: tilert scales - Returns: - tilert_mma weights and scales - """ - exp_num = mat_in.shape[0] - mat_in_s = mat_in.reshape(exp_num, 128, 56, 256) - # mat_in_s[:, :, 48:56] = 0; - mat_in_0 = mat_in_s[:, :, :16].reshape(exp_num, 128, 16, 8, 32).transpose(2, 3) - mat_in_0 = DownAllreduceWeightsConverter._swizzle_mma_16x32(mat_in_0).reshape( - exp_num, 128, -1 - ) - mat_in_1 = mat_in_s[:, :, 16:32].reshape(exp_num, 128, 16, 8, 32).transpose(2, 3) - mat_in_1 = DownAllreduceWeightsConverter._swizzle_mma_16x32(mat_in_1).reshape( - exp_num, 128, -1 - ) - mat_in_2 = mat_in_s[:, :, 32:48].reshape(exp_num, 128, 16, 8, 32).transpose(2, 3) - mat_in_2 = DownAllreduceWeightsConverter._swizzle_mma_16x32(mat_in_2).reshape( - exp_num, 128, -1 - ) - mat_in_3 = mat_in_s[:, :, 48:56].reshape(exp_num, 128, 8, 8, 32).transpose(2, 3) - mat_in_3 = DownAllreduceWeightsConverter._swizzle_mma_8x32(mat_in_3).reshape( - exp_num, 128, -1 - ) - - mat_in_swizzled = torch.cat([mat_in_0, mat_in_1, mat_in_2, mat_in_3], dim=2) - return mat_in_swizzled.reshape(exp_num, 7168, 256).contiguous(), scale_in diff --git a/python/models/utils.py b/python/models/utils.py index e85242b..b5e81e5 100644 --- a/python/models/utils.py +++ b/python/models/utils.py @@ -96,7 +96,7 @@ def linear_ramp_factor(min_value: float, max_value: float, dim: int) -> torch.Te return torch.clamp(linear_func, 0, 1) freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - if seqlen > args.original_seq_len: + if factor is not None and seqlen > args.original_seq_len: low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len) smooth = 1 - linear_ramp_factor(low, high, dim // 2) freqs = freqs / factor * (1 - smooth) + freqs * smooth diff --git a/python/profiler/__init__.py b/python/profiler/__init__.py new file mode 100644 index 0000000..e9b1cf9 --- /dev/null +++ b/python/profiler/__init__.py @@ -0,0 +1 @@ +"""Profiler utilities for TileRT.""" diff --git a/python/profiler/utils.py b/python/profiler/utils.py new file mode 100644 index 0000000..ecd83f6 --- /dev/null +++ b/python/profiler/utils.py @@ -0,0 +1,477 @@ +import os +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch + +from tilert.utils import SLICES_FOR_TILERT_OP + +# Worker names used by ExecPlanDescriptor (previously from scheduling.plan_v0) +WORKER_NAMES = [ + "Init", + "Prefetch", + "Compute", + "ExtraTask1/SyncIo", + "ExtraTask2/IoP0", + "ExtraTask3/IoP2", + "ExtraTask4", + "ExtraTask5", +] + +try: + from openpyxl import Workbook + from openpyxl.cell import Cell + from openpyxl.styles import Alignment, Border, PatternFill, Side + from openpyxl.styles.colors import COLOR_INDEX + from openpyxl.worksheet.worksheet import Worksheet +except ImportError: + print("openpyxl is not installed, profile logs will not be visualized") + Workbook = None + + +__all__ = [ + "ExcelStyleConfigs", + "ExecPlanDescriptor", + "WorkerBookVisualizer", + "visualize_profile_logs", + "parse_profile_log_tensor", + "parse_op_time", +] + + +@dataclass +class ExcelStyleConfigs: + """Excel style configurations.""" + + # 2 col * 3 stream + cols_per_worker: int = 6 + ns_per_tick: int = 1000 + + +@dataclass +class ExecPlanDescriptor: + """Exec plan descriptor.""" + + workers_def: list + op_lists: list + + +class WorkerBookVisualizer: + """Sheet visualizer.""" + + def __init__(self, exec_plan_desc: ExecPlanDescriptor): + self.exec_plan_desc = exec_plan_desc + + self.wb = Workbook() + self.wb.remove(self.wb.active) + + # Excel configs + self.style_configs = ExcelStyleConfigs() + + self.op_cols_splits = 3 + + self.time_bar_cols = 1 + self.op_stat_bar_cols = 6 + + workers_num = len(self.exec_plan_desc.workers_def) + self.op_vis_bar_cols = workers_num * self.style_configs.cols_per_worker + assert self.op_stat_bar_cols % self.op_cols_splits == 0 + + @property + def time_bar_next_col(self) -> int: + return self.time_bar_cols + 1 + + @property + def op_stat_bar_next_col(self) -> int: + return self.time_bar_next_col + self.op_stat_bar_cols + + @property + def op_vis_bar_next_col(self) -> int: + return self.op_stat_bar_next_col + self.op_vis_bar_cols + + @staticmethod + def add_region_cell( + ws: Worksheet, + value: str, + start_row: int, + start_col: int, + row_size: int = 1, + col_size: int = 1, + color_offset: int = -1, + ) -> Cell: + cell = ws.cell(row=start_row, column=start_col, value=value) + cell.alignment = Alignment(horizontal="center", vertical="center", wrap_text=True) + if color_offset >= 0: + cell.fill = PatternFill( + start_color=COLOR_INDEX[50 + color_offset], + end_color=COLOR_INDEX[50 + color_offset], + fill_type="solid", + ) + ws.merge_cells( + start_row=start_row, + start_column=start_col, + end_row=start_row + row_size - 1, + end_column=start_col + col_size - 1, + ) + return cell + + def init_layout(self, ws: Worksheet) -> None: + workers_name = self.exec_plan_desc.workers_def + worker_cols = self.style_configs.cols_per_worker + + self.add_region_cell(ws, "Op Info", 1, self.time_bar_next_col, 1, self.op_stat_bar_cols) + + for worker_id, worker_name in enumerate(workers_name): + start_col = worker_cols * worker_id + self.op_stat_bar_next_col + self.add_region_cell(ws, worker_name, 1, start_col, 1, worker_cols) + + def _parse_inst_info( + self, insts_info: list[tuple[str, float, int] | tuple[str, float] | str], op_idx: int + ) -> tuple[str, float, int]: + inst_info = insts_info[op_idx] + if isinstance(inst_info, str): + op_name, op_cost = inst_info, 0.0 + op_stream = op_idx % self.op_cols_splits + elif len(inst_info) == 2: + op_name, op_cost = inst_info + op_stream = op_idx % self.op_cols_splits + elif len(inst_info) == 3: + op_name, op_cost, op_stream = inst_info + else: + raise TypeError("Invalid inst_info format") + return op_name, op_cost, op_stream + + def add_region_cell_by_time( + self, + ws: Worksheet, + op_show_info: str, + start_time: float, + end_time: float, + op_col_start: int, + op_col_size: int, + ns_tick: int, + color_offset: int = -1, + ) -> Cell: + op_start_row_idx = np.round(start_time / ns_tick).astype(np.int32) + 2 + op_end_row_idx = np.round(end_time / ns_tick).astype(np.int32) + 2 + op_end_row_idx = max(op_end_row_idx, op_start_row_idx) + return self.add_region_cell( + ws, + op_show_info, + op_start_row_idx, + op_col_start, + max(op_end_row_idx - op_start_row_idx, 1), + op_col_size, + color_offset, + ) + + def timeline_visual_region( + self, + ws: Worksheet, + profile_logs: np.ndarray, + insts_info: list[tuple[str, float, int] | tuple[str, float] | str], + ignore_prefilling: bool = True, + ) -> None: + ns_tick = self.style_configs.ns_per_tick + self.init_layout(ws) + + total_end_time = 0 + for op_idx, op_log in enumerate(profile_logs): + op_name, op_cost, op_stream = self._parse_inst_info(insts_info, op_idx) + + if op_stream >= self.op_cols_splits: + print(f"stream_id (aka col_id) must < {self.op_cols_splits}") + raise ValueError + + valid_mask: np.ndarray = op_log >= 0 + if ignore_prefilling: + valid_mask[2:4] = False + + if np.count_nonzero(valid_mask) == 0: + continue + + op_start_time = np.min(op_log, where=valid_mask, initial=np.inf) + op_end_time = np.max(op_log, where=valid_mask, initial=-np.inf) + total_end_time = max(total_end_time, op_end_time) + + op_cost_theory = op_cost / 1000 + op_cost_actual = (op_end_time - op_start_time) / 1000 + op_bw_utils = f"{op_cost_theory / op_cost_actual * 100:.2f}" + + op_show_info = ( + f"{op_name}\n" + + f"BW Util: {op_bw_utils}%\n" + + f"Actual: {op_cost_actual:.2f}us\n" + + f"Theoretical: {op_cost_theory:.2f}us\n" + + f"Start Time: {op_start_time / 1000:.2f}us\n" + + f"End Time: {op_end_time / 1000:.2f}us" + ) + op_col_size = self.op_stat_bar_cols // self.op_cols_splits + op_col_start = self.time_bar_next_col + op_stream * op_col_size + self.add_region_cell_by_time( + ws, + op_show_info, + op_start_time, + op_end_time, + op_col_start, + op_col_size, + ns_tick, + ) + + for queue_idx, (start_time, end_time) in enumerate(zip(op_log[::2], op_log[1::2])): + if start_time < 0 or end_time < 0: + continue + task_dur = (end_time - start_time) / 1000 + task_bw_utils = f"{min(100, op_cost_theory / task_dur * 100):.2f}" + task_show_info = ( + f"{op_name}\n" + + f"Dur: {task_dur:.2f}us\n" + + f"BW Util. {task_bw_utils}%:\n" + + f"Start: {start_time / 1000:.2f}us\n" + + f"End: {end_time / 1000:.2f}us" + ) + task_col_size = self.style_configs.cols_per_worker // self.op_cols_splits + task_col_start = ( + self.op_stat_bar_next_col + + queue_idx * self.style_configs.cols_per_worker + + op_stream * task_col_size + ) + cell = self.add_region_cell_by_time( + ws, + task_show_info, + start_time, + end_time, + task_col_start, + task_col_size, + ns_tick, + queue_idx, + ) + cell.border = Border( + left=Side(style="thin"), + right=Side(style="thin"), + top=Side(style="thin"), + bottom=Side(style="thin"), + ) + + for dur_idx, dur_start in enumerate(range(0, int(total_end_time), ns_tick)): + ws.cell(row=dur_idx + 2, column=1, value=f"{(dur_start + ns_tick) / 1000:.2f}") + + def brief_table_region( + self, + ws: Worksheet, + profile_logs: np.ndarray, + insts_info: list[tuple[str, float, int] | tuple[str, float] | str], + ) -> None: + for op_idx, op_log in enumerate(profile_logs): + op_name, _, _ = self._parse_inst_info(insts_info, op_idx) + + ws.cell(row=op_idx + 2, column=self.op_vis_bar_next_col, value=op_name) + + for queue_idx, (start_time, end_time) in enumerate(zip(op_log[::2], op_log[1::2])): + if start_time < 0 or end_time < 0: + continue + task_dur = (end_time - start_time) / 1000 + ws.cell( + row=op_idx + 2, column=self.op_vis_bar_next_col + queue_idx + 1, value=task_dur + ) + + def add_sheet(self, profile_logs: np.ndarray, sheet_name: str) -> "WorkerBookVisualizer": + """Add a sheet to the workbook.""" + wb = self.wb + insts_info = self.exec_plan_desc.op_lists + + ws = wb.create_sheet(sheet_name) + self.timeline_visual_region(ws, profile_logs, insts_info) + self.brief_table_region(ws, profile_logs, insts_info) + + return self + + def add_sm_brief_sheet( + self, profile_logs: np.ndarray, sheet_name: str + ) -> "WorkerBookVisualizer": + """Add a brief sheet to workbook which contains min/max start/end and duration among SMs""" + wb = self.wb + insts_info = self.exec_plan_desc.op_lists + ws = wb.create_sheet(sheet_name) + + profile_logs = np.transpose(profile_logs, (1, 0, 2)) + + # 1. init layout + workers_name = self.exec_plan_desc.workers_def + worker_metric_def = [ + "min_start", + "max_end", + "min_dur", + "max_dur", + "mean_dur", + "std_dur", + ] + + worker_cols = len(worker_metric_def) + + self.add_region_cell(ws, "Op Info", 1, self.time_bar_next_col, 1, self.op_stat_bar_cols) + + for worker_id, worker_name in enumerate(workers_name): + start_col = worker_cols * worker_id + self.op_stat_bar_next_col + self.add_region_cell(ws, worker_name, 1, start_col, 1, worker_cols) + for metric_id, metric_name in enumerate(worker_metric_def): + start_col_metric = start_col + metric_id + self.add_region_cell(ws, metric_name, 2, start_col_metric, 1, 1) + + # 2. calc metrics + # profile_logs: (num_ops, num_sm, num_task*2) + for op_idx, op_profile_log in enumerate(profile_logs): + valid_mask = (op_profile_log >= 0) & (op_profile_log < 1e9) + # skip if this op is fully invalid + if not np.any(valid_mask): + continue + + op_name, _, _ = self._parse_inst_info(insts_info, op_idx) + self.add_region_cell(ws, op_name, op_idx + 3, self.time_bar_next_col, 1, 2) + + for queue_idx in range(op_profile_log.shape[1] // 2): + starts = op_profile_log[:, queue_idx * 2] + ends = op_profile_log[:, queue_idx * 2 + 1] + + valid_mask = ( + (starts >= 0) & (starts < 1e9) & (ends >= 0) & (ends < 1e9) & (starts <= ends) + ) + + valid_starts = starts[valid_mask] / 1000 + valid_ends = ends[valid_mask] / 1000 + + if len(valid_starts) == 0: + continue + + min_start = np.min(valid_starts) + max_end = np.max(valid_ends) + durations = valid_ends - valid_starts + + metrics_values = [ + min_start, + max_end, + np.min(durations), + np.max(durations), + np.mean(durations), + np.std(durations), + ] + + # row_idx start from 3, because {1: work_name, 2: metric_name} + # col_idx start from worker::start_col + start_row = op_idx + 3 + start_col = worker_cols * queue_idx + self.op_stat_bar_next_col + color_offset = queue_idx + + for i, value in enumerate(metrics_values): + # color mean and std dev + cell_color = color_offset if i >= 4 else -1 + self.add_region_cell(ws, value, start_row, start_col + i, 1, 1, cell_color) + + return self + + def save(self, out_path: str) -> None: + """Save the workbook to a file.""" + os.makedirs(os.path.dirname(out_path), exist_ok=True) + self.wb.save(out_path) + + +def visualize_profile_logs( + all_profile_logs: np.ndarray, + out_path: str, + inst2opname: list[tuple[str, float, int] | tuple[str, float] | str], + with_mean: bool = False, + with_max: bool = False, +) -> None: + """Visualize profile logs.""" + valid_ctas = np.argwhere(np.any(all_profile_logs != 0, axis=(1, 2)))[:, 0] + filtered_logs = all_profile_logs[valid_ctas] + filtered_masks = np.logical_and(filtered_logs >= 0, filtered_logs < 1e9) + mean_profile_logs = np.mean(filtered_logs, axis=0, where=filtered_masks) + mean_profile_logs[np.isnan(mean_profile_logs)] = -1 + if filtered_logs.size == 0: + return + assemble_profile_logs = np.zeros_like(filtered_logs[0]) + assemble_profile_logs[:, ::2] = np.min( + filtered_logs[..., ::2], axis=0, where=filtered_masks[..., ::2], initial=np.inf + ) + assemble_profile_logs[:, 1::2] = np.max( + filtered_logs[..., 1::2], axis=0, where=filtered_masks[..., 1::2], initial=-np.inf + ) + assemble_profile_logs[np.isinf(assemble_profile_logs)] = -1 + + visualizer = WorkerBookVisualizer(ExecPlanDescriptor(WORKER_NAMES, inst2opname)) + if with_mean: + visualizer.add_sheet(mean_profile_logs, "mean") + if with_max: + raise NotImplementedError("with_max is not implemented") + + visualizer.add_sm_brief_sheet(filtered_logs, "mean_sm_brief") + for block_idx, profile_logs in enumerate(filtered_logs): + profile_logs[profile_logs > 1e9] = -1 + visualizer.add_sheet(profile_logs, f"block_{block_idx}") + visualizer.save(out_path) + + +def parse_profile_log_tensor( + profile_logs_tensor: torch.Tensor, + out_path: str, + inst2opname: Any, + with_mean: bool = False, +) -> None: + """Parse a profile log tensor into a dictionary. + + Args: + profile_log_tensor: The profile log tensor. + out_path: The path to save the profile logs. + inst2opname: The mapping from instance index to operation name. + + list[tuple[str, float, int] | tuple[str, float] | str] + + Returns: + None. + """ + # Remove the extra slices for storing instructions and glb bars. + profile_logs_tensor = profile_logs_tensor[:-SLICES_FOR_TILERT_OP, :, :] + + profile_logs = profile_logs_tensor.cpu().detach().numpy() + valid_insts_logs = np.any(profile_logs != 0, axis=(1, 2)) + profile_logs = profile_logs[valid_insts_logs] + valid_blocks_logs = np.any(profile_logs != 0, axis=(0, 2)) + profile_logs = profile_logs[:, valid_blocks_logs, :] + # Return if no valid blocks logs are found. + if profile_logs.size == 0: + print("Warning: No profile logs available.") + return + profile_logs = np.transpose(profile_logs, (1, 0, 2)) + ctx_start_times = profile_logs[:, 0, 0] + profile_logs = profile_logs[:, 1:, :] + profile_logs = (profile_logs - ctx_start_times[:, None, None]).astype(np.float32) / 1.855 + + if Workbook is not None: + visualize_profile_logs(profile_logs, out_path, inst2opname, with_mean) + + +def parse_op_time(profile_logs: torch.Tensor, op_idx: int = 0, block_idx: int = 0) -> None: + data = profile_logs[op_idx, block_idx, :].cpu().numpy() + max_time = data.max() + start_time = data.min() + FREQUENCY = 1850.0 + + worker_names = [ + "controller", + " sync_io", + " io_p0", + " io_p1", + " io_p2", + " consumer", + " extra1", + " extra2", + ] + for i, worker_name in enumerate(worker_names): + if data[i * 2] != max_time: + print( + f"{worker_name}:\tstart:{(data[i * 2] - start_time) / FREQUENCY:.3f}, " + f"duration:{(data[i * 2 + 1] - data[i * 2]) / FREQUENCY:.3f}, " + f"end:{(data[i * 2 + 1] - start_time) / FREQUENCY:.3f}" + ) diff --git a/python/tilert_init.py b/python/tilert_init.py index aa9ae29..d0cd30f 100644 --- a/python/tilert_init.py +++ b/python/tilert_init.py @@ -8,33 +8,11 @@ ] -def tilert_init( - placeholder: torch.Tensor | None = None, -) -> None: - """Tilert init operation. +def tilert_init() -> None: + """Tilert init operation.""" + torch.ops.tilert.tilert_init_op() - Args: - placeholder: torch.Tensor, - A placeholder tensor. - """ - if placeholder is None: - placeholder = torch.zeros(0).to(torch.device("cuda")) - torch.ops.tilert.tilert_init_op( - placeholder, - ) - -def tilert_force_init( - placeholder: torch.Tensor | None = None, -) -> None: - """Tilert force init operation. - - Args: - placeholder: torch.Tensor, - A placeholder tensor. - """ - if placeholder is None: - placeholder = torch.zeros(0).to(torch.device("cuda")) - torch.ops.tilert.tilert_force_init_op( - placeholder, - ) +def tilert_force_init() -> None: + """Tilert force init operation.""" + torch.ops.tilert.tilert_force_init_op()