diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 154383d..f4b6702 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -36,6 +36,6 @@ jobs:
- name: Install lint dependencies
run: |
python -m pip install --upgrade pip
- pip install --no-cache-dir -r requirements-ci.txt
+ pip install --no-cache-dir -r requirements-dev.txt
- name: Run all linting checks
run: ./scripts/lint.sh
diff --git a/README.md b/README.md
index 75fdf41..8847a29 100644
--- a/README.md
+++ b/README.md
@@ -6,16 +6,29 @@
- Installation |
- Getting Started
+ Overview ·
+ Generation ·
+ MTP Generation ·
+ Installation ·
+ News
-## News
+______________________________________________________________________
-- **\[2025-12-23\]** ⚡ **[v0.1.1](https://github.com/tile-ai/TileRT/releases/tag/v0.1.1)** — Achieved ~35% reduction in end-to-end token generation latency on a single node with 8× NVIDIA B200. See our latest benchmarks for detailed measurements.
+
-- **\[2025-11-20\]** 🚀 **[v0.1.0-alpha.1](https://github.com/tile-ai/TileRT/releases/tag/v0.1.0-alpha.1)** — Initial release of TileRT for DeepSeek-V3.2-Exp, designed for **ultra-low-latency** inference. Available on [PyPI](https://pypi.org/project/tilert) and [HuggingFace](https://huggingface.co/Tile-AI/DeepSeek-V3.2-Exp-TileRT).
+## 📰 News
+
+- :fire: **2026-01-26 · [v0.1.2-alpha.1](https://github.com/tile-ai/TileRT/releases/tag/v0.1.2-alpha.1)**. **Multi-Token Prediction (MTP) lands in TileRT**. With mtp=3, we observe decoding rates up to **590 tokens/s** under synthetic workloads.
+
+- ⚡ **2025-12-23 · [v0.1.1](https://github.com/tile-ai/TileRT/releases/tag/v0.1.1)**. Achieved ~**35% further reduction** (3 ~ 4x speedup over baseline) in end-to-end token generation latency on a single node with **8× NVIDIA B200**.
+
+- 🚀 **2025-11-20 · [v0.1.0-alpha.1](https://github.com/tile-ai/TileRT/releases/tag/v0.1.0-alpha.1)**. Initial public release for **DeepSeek-V3.2-Exp**, targeting **ultra-low-latency** inference. Available on [PyPI](https://pypi.org/project/tilert) and [HuggingFace](https://huggingface.co/Tile-AI/DeepSeek-V3.2-Exp-TileRT).
+
+______________________________________________________________________
+
+
## TileRT: Pushing LLM Latency to the Limit
@@ -23,7 +36,7 @@ TileRT is an experimental project exploring core compiler techniques for serving

-Figure 1. Sequence generation with TileRT.
+Figure 1. Sequence generation with TileRT, now enhanced with Multi-Token Prediction (MTP) to accelerate inference.
We evaluated TileRT’s preliminary performance using the [**DeepSeek-V3.2-Exp**](https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp) model (without lossy optimizations such as quantization or distillation) with a batch size of 1 on 8× NVIDIA B200 GPUs. As shown in the benchmark below, TileRT demonstrates substantial improvements over existing inference systems.
@@ -39,6 +52,8 @@ To achieve this, TileRT introduces a **tile-level runtime engine**. Leveraging a
The project is actively evolving, and the underlying compiler techniques will be gradually shared with the community as they are integrated into **TileLang** and **TileScale**.
+______________________________________________________________________
+
## Installation
- [Prerequisites](#prerequisites)
@@ -145,7 +160,7 @@ docker run --gpus all -it \
tilert:v0.1.0
```
-Once inside the container, you can run the following Python script:
+Once inside the container, run the following Python script to perform text generation:
```python
from tilert.models.deepseek_v3_2.dsa_show_hands import ShowHandsGenerator
@@ -153,23 +168,28 @@ from tilert.models.deepseek_v3_2.dsa_show_hands import ShowHandsGenerator
generator: ShowHandsGenerator = ShowHandsGenerator(
max_new_tokens=1000,
model_weights_dir=MODEL_WEIGHTS_DIR,
+ with_mtp=False, # Disable MTP
)
generator.from_pretrained()
-prompt = """Tell me three jokes:
-
-1. A dad joke,
-2. A programmer joke,
-3. A joke that only makes sense if you've ever tried to train a large language model.
-Keep each joke under 15 words.
-"""
+prompt = (
+ "Tell me three jokes:\n\n"
+ "1. A dad joke,\n"
+ "2. A programmer joke,\n"
+ "3. A joke that only makes sense if you've ever tried "
+ "to train a large language model.\n"
+ "Keep each joke under 15 words."
+)
print("Prompt:", prompt)
print("Completion:")
-completion: generator.generate(prompt)
+completion = generator.generate(prompt)
```
-For instance, using the above prompt, TileRT might generate:
+For example, TileRT may generate:
+
+
+Sample output (click to expand)
```text
1. I'm afraid for the calendar. Its days are numbered.
@@ -177,7 +197,75 @@ For instance, using the above prompt, TileRT might generate:
3. My model's loss is low, but its answers are still nonsense. Overfitting.
```
-This example gives you a quick idea of the type of output you can expect from the precompiled model.
+
+
+This example demonstrates basic single-step autoregressive generation using the precompiled model.
+
+### Running the Generation Example with Multi-Token Prediction (MTP)
+
+> \[!IMPORTANT\]
+> **Weights update required for MTP.** Multi-Token Prediction (MTP) introduces additional **MTP heads** in the model weights.
+> If you were using TileRT **before v0.1.1**, please make sure you download the **latest weights** from Hugging Face.
+> Older weights do not include the required MTP heads and will fail to run when MTP is enabled.
+
+TileRT also supports Multi-Token Prediction (MTP), which allows the model to generate multiple tokens per forward pass and reduces sequential decoding depth.
+
+To better illustrate MTP behavior, we use a longer prompt that encourages extended generation:
+
+```python
+from tilert.models.deepseek_v3_2.dsa_show_hands import ShowHandsGenerator
+
+generator: ShowHandsGenerator = ShowHandsGenerator(
+ max_new_tokens=1000,
+ model_weights_dir=MODEL_WEIGHTS_DIR,
+ with_mtp=True, # Enable MTP
+)
+generator.from_pretrained()
+prompt = "Tell me 10 jokes, keep them all under 100 words."
+
+print("Prompt:", prompt)
+print("Completion:")
+completion = generator.generate(prompt)
+```
+
+When MTP is enabled, TileRT may report statistics similar to the following during generation:
+
+```text
+Accepted length: mean=2.77, min=1, max=4
+```
+
+This indicates that, on average, multiple tokens are accepted per decoding step under MTP.
+
+
+Sample output (click to expand)
+
+```text
+Of course! Here are 10 short jokes for you.
+
+1. I told my wife she was drawing her eyebrows too high. She looked surprised.
+
+2. I invented a new word: Plagiarism.
+
+3. Why don't scientists trust atoms? Because they make up everything.
+
+4. I'm reading a book on anti-gravity. It's impossible to put down.
+
+5. What's the best thing about Switzerland? I don't know, but the flag is a big plus.
+
+6. I told my computer I needed a break, and now it won't stop sending me vacation ads.
+
+7. Why did the scarecrow win an award? He was outstanding in his field.
+
+8. What do you call a fake noodle? An impasta.
+
+9. I told my suitcase there's no vacation, and now it has a lot of baggage.
+
+10. Why don't skeletons fight each other? They don't have the guts.
+```
+
+
+
+This example highlights how MTP enables TileRT to efficiently generate longer outputs by accepting multiple tokens per decoding step, while preserving the same Python API interface.
For more details, please refer to the [generation script](https://github.com/tile-ai/TileRT/blob/main/python/generate.py).
diff --git a/assets/generate.gif b/assets/generate.gif
index 13d34f3..3d73a90 100644
Binary files a/assets/generate.gif and b/assets/generate.gif differ
diff --git a/python/__init__.py b/python/__init__.py
index a1cbea0..400d6a0 100644
--- a/python/__init__.py
+++ b/python/__init__.py
@@ -40,7 +40,8 @@ def _load_library(filename: str) -> Any:
lib_path = Path(__file__).parent / filename
try:
- return ctypes.CDLL(str(lib_path))
+ torch.ops.load_library(str(lib_path))
+ return lib_path
except Exception as e:
raise RuntimeError(f"Failed to load library from {lib_path}") from e
diff --git a/python/generate.py b/python/generate.py
index d817ebb..79f61b7 100644
--- a/python/generate.py
+++ b/python/generate.py
@@ -1,6 +1,9 @@
"""Text generation script for TileRT."""
from argparse import ArgumentParser
+from typing import cast
+
+import numpy as np
from tilert.models.deepseek_v3_2.dsa_show_hands import ShowHandsGenerator
@@ -16,7 +19,16 @@ def parse_args(): # type: ignore
parser.add_argument("--max-new-tokens", type=int, default=4000, help="Max tokens to generate")
parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
parser.add_argument("--interactive", action="store_true")
- parser.add_argument("--fp8", action="store_true")
+ parser.add_argument(
+ "--with-mtp",
+ action="store_true",
+ help="Enable MTP (Multi-Token Prediction) for speculative decoding",
+ )
+ parser.add_argument(
+ "--use-random-weights",
+ action="store_true",
+ help="Use random weights instead of pretrained (for testing MTP without real weights)",
+ )
return parser.parse_args()
@@ -25,7 +37,15 @@ def parse_args(): # type: ignore
usage:
execute below command under tilert root directory:
+ # Standard generation with pretrained weights:
python python/generate.py --model-weights-dir "xxxx" 2>&1 | tee test.log
+
+ # MTP generation with random weights (for testing):
+ python python/generate.py --model-weights-dir "xxxx" --with-mtp \
+ --use-random-weights 2>&1 | tee test.log
+
+ # MTP generation with pretrained weights (when available):
+ python python/generate.py --model-weights-dir "xxxx" --with-mtp 2>&1 | tee test.log
"""
args = parse_args()
@@ -33,14 +53,15 @@ def parse_args(): # type: ignore
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
model_weights_dir=args.model_weights_dir,
- enable_fp8_ops=args.fp8,
+ with_mtp=args.with_mtp,
)
- # uncomment to use random weights
- # generator.init_random_weights()
-
- # use pretrained weights
- generator.from_pretrained()
+ if args.use_random_weights:
+ print("Initializing with random weights...")
+ generator.init_random_weights()
+ else:
+ print("Loading pretrained weights...")
+ generator.from_pretrained()
# simple memoryless interactive mode
if args.interactive:
@@ -53,14 +74,70 @@ def parse_args(): # type: ignore
else:
# This prompt is to test the model’s ability to follow instructions
# (in terms of quantity, type, and length) while keeping it fun.
+ print("==== Performance ====")
prompt = "Tell me 10 jokes, keep them all under 100 words."
-
print("Prompt:", prompt)
- print("Completion:")
- completion: str = generator.generate(prompt) # type: ignore[has-type]
+ all_times = []
+ all_accepted = []
+ for _iter in range(20):
+ if _iter % 5 == 0:
+ print(f"Executing iter {_iter}...")
+ results, time_list, accepted_counts = cast(
+ tuple[str, list[float], list[int]],
+ generator.generate(prompt, False), # type: ignore[has-type]
+ )
+ all_times.append(time_list)
+ all_accepted.append(accepted_counts)
+
+ if args.with_mtp:
+ for token_num in range(100, 300, 100):
+ times_to_token_num = []
+ for time_list, accepted_list in zip(all_times, all_accepted):
+ if len(time_list) > 5 and len(accepted_list) > 5:
+ times = time_list[5:]
+ accepted = accepted_list[5:]
+ cumsum_tokens = np.cumsum(accepted)
+ cumsum_times = np.cumsum(times)
+ # Find index where we reach token_num tokens
+ idx = np.searchsorted(cumsum_tokens, token_num)
+ if idx < len(cumsum_times):
+ times_to_token_num.append(cumsum_times[idx])
+ if times_to_token_num:
+ mean_total_time = np.mean(times_to_token_num)
+ mean_time = mean_total_time / token_num
+ speed = 1 / mean_time
+ out_str = (
+ f"**Perf@{token_num}: {speed:.3f} tokens/s & "
+ f"{(mean_time * 1000):.3f} ms**"
+ )
+ print(out_str)
+
+ # Print accepted tokens statistics
+ flat_accepted = [a for accepted_list in all_accepted for a in accepted_list]
+ if flat_accepted:
+ avg_accepted = sum(flat_accepted) / len(flat_accepted)
+ min_accepted = min(flat_accepted)
+ max_accepted = max(flat_accepted)
+ print(
+ f"**Accepted length: mean={avg_accepted:.2f}, "
+ f"min={min_accepted}, max={max_accepted}**"
+ )
+ else:
+ all_times_np = np.array(all_times)
+ for token_num in range(100, 300, 100):
+ mean_time = np.mean(all_times_np[..., 5:token_num])
+ speed = 1 / mean_time
+ out_str = (
+ f"**Perf@{token_num}: {speed:.3f} tokens/s & {(mean_time * 1000):.3f} ms**"
+ )
+ print(out_str)
+ print(results)
# This prompt is used to test long sequence generation
prompt = "Hi, can you tell me a very long story, with roughly 3000 words?"
print("Prompt:", prompt)
print("Completion:")
- completion = generator.generate(prompt) # type: ignore[has-type]
+ completion, _, _ = generator.generate(prompt) # type: ignore[has-type]
+
+ print("Cleaning up...")
+ generator.cleanup()
diff --git a/python/models/base.py b/python/models/base.py
index e45132d..b8a8219 100644
--- a/python/models/base.py
+++ b/python/models/base.py
@@ -9,6 +9,7 @@
from tilert import logger
from tilert.models.deepseek_config import get_rank, get_world_size
+from tilert.models.deepseek_v3_2.params import BaseParams
from tilert.models.preprocess import WeightLoader
from tilert.utils import get_profile_log_tensor
@@ -52,9 +53,10 @@ def __init__(
self.flag_enable_tilert = False
- if compute_kernel_type not in ["bf16", "fp8"]:
+ if compute_kernel_type not in ["bf16", "fp8", "fp8mma"]:
raise ValueError(
- f"Invalid compute kernel type: {compute_kernel_type}, must be one of bf16, fp8."
+ f"Invalid compute kernel type: {compute_kernel_type}, \
+ must be one of bf16, fp8, fp8mma."
)
self.compute_kernel_type = compute_kernel_type
@@ -215,7 +217,7 @@ def tilert_forward(self, *args: Any, **kwargs: Any) -> Any: # noqa: U100
raise NotImplementedError("Tilert forward not implemented")
@abstractmethod
- def to_tilert_weights(self, *args: Any, **kwargs: Any) -> None:
+ def to_tilert_weights(self, *args: Any, **kwargs: Any) -> BaseParams | None:
"""Convert weights to tilert.
Args:
diff --git a/python/models/deepseek_v3_2/__init__.py b/python/models/deepseek_v3_2/__init__.py
new file mode 100644
index 0000000..4b8633b
--- /dev/null
+++ b/python/models/deepseek_v3_2/__init__.py
@@ -0,0 +1 @@
+"""DeepSeek v3.2 model package."""
diff --git a/python/models/deepseek_v3_2/dsa_mtp_e2e_show_hands.py b/python/models/deepseek_v3_2/dsa_mtp_e2e_show_hands.py
new file mode 100644
index 0000000..3c3dcac
--- /dev/null
+++ b/python/models/deepseek_v3_2/dsa_mtp_e2e_show_hands.py
@@ -0,0 +1,158 @@
+"""DSA MTP E2E show hands for DeepSeek v3.2."""
+
+from typing import Any
+
+import torch
+
+from tilert.models.deepseek_v3_2.dsa_show_hands import ShowHandsDSALayer
+
+__all__ = [
+ "ShowHandsDsaMtpE2eLayer",
+ "dsa_mtp_e2e_show_hands_prepare_money",
+ "dsa_mtp_e2e_show_hands",
+ "dsa_mtp_e2e_show_hands_reset",
+ "dsa_mtp_e2e_show_hands_go_home",
+]
+
+
+def dsa_mtp_e2e_show_hands_prepare_money(
+ params: list[torch.Tensor],
+ temp_vars: list[torch.Tensor],
+ cache_vars: list[torch.Tensor],
+ profile_logs: torch.Tensor,
+) -> Any:
+ """Prepare money for MTP E2E show hands."""
+ return torch.ops.tilert.dsa_mtp_e2e_show_hands_prepare_money(
+ params, temp_vars, cache_vars, profile_logs
+ )
+
+
+def dsa_mtp_e2e_show_hands(draft_tokens: torch.Tensor) -> Any:
+ """Show hands with MTP E2E."""
+ return torch.ops.tilert.dsa_mtp_e2e_show_hands(draft_tokens)
+
+
+def dsa_mtp_e2e_show_hands_reset(placeholder: torch.Tensor) -> Any:
+ """Reset MTP E2E show hands."""
+ return torch.ops.tilert.dsa_mtp_e2e_show_hands_reset(placeholder)
+
+
+def dsa_mtp_e2e_show_hands_go_home(placeholder: torch.Tensor) -> Any:
+ """Cleanup MTP E2E show hands."""
+ return torch.ops.tilert.dsa_mtp_e2e_show_hands_go_home(placeholder)
+
+
+def dsa_mtp_e2e_show_hands_set_prefill_valid_tokens(
+ placeholder: torch.Tensor, num_valid_tokens: int
+) -> Any:
+ """Set the number of valid (non-padding) tokens for prefill mode.
+
+ This controls how many tokens are copied from draft_tokens to predicted_tokens
+ during prefill. Should be called before forward() when the chunk has padding.
+
+ Args:
+ placeholder: Placeholder tensor for PyTorch dispatch (not used).
+ num_valid_tokens: Number of valid tokens in the chunk (1-4).
+ """
+ return torch.ops.tilert.dsa_mtp_e2e_show_hands_set_prefill_valid_tokens(
+ placeholder, num_valid_tokens
+ )
+
+
+class ShowHandsDsaMtpE2eLayer(ShowHandsDSALayer):
+ """Show hands DSA MTP E2E layer for DeepSeek v3.2.
+
+ Inherits from ShowHandsDSALayer and adds MTP-specific functionality.
+ """
+
+ # MTP constants
+ NUM_MTP = 3
+ MTP_SEQ_LEN = NUM_MTP + 1 # 4
+
+ def __init__(
+ self,
+ max_seq_len: int,
+ with_weight_conversion: bool = True,
+ ) -> None:
+ super().__init__(
+ max_seq_len=max_seq_len,
+ model_path="",
+ with_weight_conversion=with_weight_conversion,
+ with_mtp=True,
+ )
+
+ def _get_num_cache_layers(self) -> int:
+ """Return number of cache layers (+1 for shared MTP cache)."""
+ return int(self.NUM_LAYERS) + 1
+
+ def _prepare_money(
+ self,
+ params: list[torch.Tensor],
+ intermediates: list[torch.Tensor],
+ caches: list[torch.Tensor],
+ profile_logs: torch.Tensor,
+ ) -> None:
+ """Prepare money for MTP E2E show hands."""
+ dsa_mtp_e2e_show_hands_prepare_money(params, intermediates, caches, profile_logs)
+
+ def _show_hands(self, draft_tokens: torch.Tensor) -> Any:
+ """Run MTP E2E show hands forward."""
+ return dsa_mtp_e2e_show_hands(draft_tokens.cpu())
+
+ def _reset_sequence_impl(self) -> None:
+ """Reset MTP E2E sequence."""
+ dsa_mtp_e2e_show_hands_reset(self.placeholder)
+
+ def _cleanup_impl(self) -> None:
+ """Cleanup MTP E2E resources."""
+ dsa_mtp_e2e_show_hands_go_home(self.placeholder)
+
+ def set_prefill_valid_tokens(self, num_valid_tokens: int) -> None:
+ """Set the number of valid tokens for prefill mode.
+
+ This controls how many tokens are copied from draft_tokens to predicted_tokens
+ during prefill. Should be called before forward() when the chunk has padding.
+
+ Args:
+ num_valid_tokens: Number of valid tokens in the chunk (1-4).
+ """
+ dsa_mtp_e2e_show_hands_set_prefill_valid_tokens(self.placeholder, num_valid_tokens)
+
+ def get_next_draft_tokens(self, device_id: int = 0) -> torch.Tensor:
+ """Get next_draft_tokens from the specified device.
+
+ Args:
+ device_id: Device ID to get results from.
+
+ Returns:
+ next_draft_tokens tensor of shape [1, MTP_SEQ_LEN].
+ """
+ intermediates, _, _, _ = self._get_device_result(device_id)
+ # next_draft_tokens is at index 38 (DsaTempVars::kNextDraftTokensIdx)
+ return intermediates[38]
+
+ def get_num_accepted(self, device_id: int = 0) -> int:
+ """Get number of accepted tokens from the specified device.
+
+ Args:
+ device_id: Device ID to get results from.
+
+ Returns:
+ Number of accepted tokens.
+ """
+ intermediates, _, _, _ = self._get_device_result(device_id)
+ # accepted_tokens (num_accepted) is at index 37 (DsaTempVars::kAcceptedTokensIdx)
+ return int(intermediates[37][0].item())
+
+ def get_predicted_tokens(self, device_id: int = 0) -> torch.Tensor:
+ """Get predicted_tokens from the specified device.
+
+ Args:
+ device_id: Device ID to get results from.
+
+ Returns:
+ predicted_tokens tensor containing main model predictions.
+ """
+ intermediates, _, _, _ = self._get_device_result(device_id)
+ # predicted_tokens is at index 35 (DsaTempVars::kPredictedTokensIdx)
+ return intermediates[35]
diff --git a/python/models/deepseek_v3_2/dsa_show_hands.py b/python/models/deepseek_v3_2/dsa_show_hands.py
index 146d6bf..ca781a2 100644
--- a/python/models/deepseek_v3_2/dsa_show_hands.py
+++ b/python/models/deepseek_v3_2/dsa_show_hands.py
@@ -17,27 +17,21 @@
from tilert.models.base import TileRTModule
from tilert.models.deepseek_v3_2.model_args import ModelArgs as ModelArgsV3_2
from tilert.models.deepseek_v3_2.params import (
+ BaseParams,
CacheVars,
DenseLayerParamsKeys,
+ Dsa671BModelInitializer,
IntermediateMapper,
- LLMHeadParams,
- MlaFp8Params,
- MlaParams,
- MLPFp8Params,
- MLPParams,
- MoEFp8Params,
MoELayerParamsKeys,
- MoEParams,
TempVars,
- gen_down_allreduce_fp8_params,
- gen_expert_down_allreduce_fp8_params,
- gen_unproj_o_allreduce_fp8_params,
)
from tilert.models.preprocess.weight_utils import (
+ DownAllreduceWeightsConverter,
ExpertSelectUpGateSiLUWeightsConverter,
RMSNormHeadProjWeightsConverter,
- RMSNormProjQAKVAKIRopeWeightsConverter,
+ RMSNormProjQAKVAKIWeightsConverter,
RMSNormUpGateSiLUWeightsConverter,
+ UnProjOAllreduceWeightsConverter,
)
from tilert.models.utils import precompute_freqs_cis
from tilert.tilert_init import tilert_init
@@ -47,6 +41,45 @@
"ShowHandsGenerator",
]
+# MTP layer ID constant
+MTP_LAYER_ID = 61
+
+# MTP params keys order (for layer 61)
+MTPPreprocessParamsKeys = [
+ "embedding_rmsnorm_gamma",
+ "hidden_rmsnorm_gamma",
+ "eh_proj_weights",
+]
+
+MTPMlaParamsKeys = [
+ "x_rmsnorm_gamma",
+ "qkv_wa_weights",
+ "qkv_wa_scales",
+ "k_weights",
+ "k_bias",
+ "q_rmsnorm_gamma",
+ "q_wb_weights",
+ "q_wb_scales",
+ "id_score_weights",
+ "wkv_b1_weights",
+ "wkv_b1_scales",
+ "kv_rmsnorm_gamma",
+ "wkv_b2_weights",
+ "wkv_b2_scales",
+ "unproj_weights",
+ "unproj_scales",
+]
+
+MTPMoeParamsKeys = [
+ "unproj_o_gamma",
+ "exp_proj_weights",
+ "exp_bias",
+ "exp_upgate_weights",
+ "exp_upgate_scales",
+ "exp_down_weights",
+ "exp_down_scales",
+]
+
def stats_time(time_list: list[float], title: str) -> None:
if len(time_list) > 0:
@@ -62,16 +95,15 @@ def stats_time(time_list: list[float], title: str) -> None:
def dsa_show_hands_prepare_money(
- enable_fused_op: bool,
- enable_fp8_ops: bool,
params: list[torch.Tensor],
temp_vars: list[torch.Tensor],
cache_vars: list[torch.Tensor],
profile_logs: torch.Tensor,
+ forward_max_seq_len: int,
) -> Any:
"""Prepare money for show hands"""
return torch.ops.tilert.dsa_show_hands_prepare_money(
- enable_fused_op, enable_fp8_ops, params, temp_vars, cache_vars, profile_logs
+ params, temp_vars, cache_vars, profile_logs, forward_max_seq_len
)
@@ -90,48 +122,162 @@ def dsa_show_hands_go_home(placeholder: torch.Tensor) -> Any:
return torch.ops.tilert.dsa_show_hands_go_home(placeholder)
+# Put dirty conversion code here.
+# TODO: better way to handle conversion.
def _convert_weights_on_demand(
state_dicts: dict[str, torch.Tensor],
+ skip_mtp_layer: bool = False,
) -> dict[str, torch.Tensor]:
- """Convert weights on demand"""
+ """Convert weights on demand.
+
+ Args:
+ state_dicts: Dictionary of weights to convert.
+ skip_mtp_layer: If True, skip layer 61 (MTP layer) weights.
+ """
res_dicts = {}
for key, value in state_dicts.items():
+ # Skip layer 61 (MTP layer) MLA/MoE/preprocess weights if requested,
+ # but NOT lm_head and model.norm.weight (used by main model head)
+ if (
+ skip_mtp_layer
+ and f"layer_{MTP_LAYER_ID}_" in key
+ and "lm_head.weight" not in key
+ and "model.norm.weight" not in key
+ ):
+ res_dicts[key] = value
+ continue
+
if "qkv_wa_weights" in key: # first op
weight_key = key
scale_key = key.replace("qkv_wa_weights", "qkv_wa_scales")
gamma_key = key.replace("qkv_wa_weights", "x_rmsnorm_gamma")
- common_weights = RMSNormProjQAKVAKIRopeWeightsConverter.tilert_to_common(
+ common_weights = RMSNormProjQAKVAKIWeightsConverter.tilert_to_common(
state_dicts[weight_key],
state_dicts[scale_key],
state_dicts[gamma_key],
)
conv_weights = (
- RMSNormProjQAKVAKIRopeWeightsConverter.common_to_tilert_native_bf16_warp_gemv(
+ RMSNormProjQAKVAKIWeightsConverter.common_to_tilert_native_bf16_warp_gemv(
*common_weights
)
)
res_dicts[key] = conv_weights[0]
+ elif "unproj_weights" in key: # unprojo_allreduce op
+ weight_key = key
+ scale_key = key.replace("unproj_weights", "unproj_scales")
+ weights, scales = UnProjOAllreduceWeightsConverter.tilert_to_tilert_112sm_mma(
+ state_dicts[weight_key],
+ state_dicts[scale_key],
+ )
+ res_dicts[weight_key] = weights
+ res_dicts[scale_key] = scales
+ state_dicts[weight_key] = None
+ elif "unproj_scales" in key: # skip unprojo_allreduce op:: scales
+ pass
elif "exp_upgate_weights" in key: # expert select up gate silu op
weight_key = key
scale_key = key.replace("exp_upgate_weights", "exp_upgate_scales")
weights_and_scales = ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm_mma(
- state_dicts[weight_key],
- state_dicts[scale_key],
+ state_dicts[weight_key], state_dicts[scale_key]
)
res_dicts[key] = weights_and_scales
+ state_dicts[weight_key] = None
+
elif "upgate_weights" in key: # rmsnorm up gate silu op
weight_key = key
scale_key = key.replace("upgate_weights", "upgate_scales")
- weights_and_scales = RMSNormUpGateSiLUWeightsConverter.tilert_to_tilert_144sm(
+ weights_and_scales = RMSNormUpGateSiLUWeightsConverter.tilert_to_tilert_144sm_mma(
state_dicts[weight_key],
state_dicts[scale_key],
)
res_dicts[key] = weights_and_scales
+ state_dicts[weight_key] = None
+ elif "down_weights" in key: # expert down allreduce op
+ weight_key = key
+ scale_key = key.replace("down_weights", "down_scales")
+ weights_swizzled, scales = DownAllreduceWeightsConverter.tilert_to_tilert_mma(
+ state_dicts[weight_key],
+ state_dicts[scale_key],
+ )
+ res_dicts[weight_key] = weights_swizzled
+ res_dicts[scale_key] = scales
+ state_dicts[weight_key] = None
elif "lm_head.weight" in key: # head projection weights
weights = RMSNormHeadProjWeightsConverter.tilert_to_tilert_native_bf16_warp_gemv(
state_dicts[key]
)
res_dicts[key] = weights
+ state_dicts[key] = None
+ else:
+ res_dicts[key] = value
+
+ return res_dicts
+
+
+def _convert_mtp_weights_on_demand(
+ state_dicts: dict[str, torch.Tensor],
+) -> dict[str, torch.Tensor]:
+ """Convert MTP layer weights on demand to TileRT optimized format.
+
+ This applies conversions specifically for MTP layer weights (layer 61).
+ Only processes keys that contain 'layer_61_'.
+ Note: lm_head and model.norm.weight are reused from main model (already converted).
+ """
+ res_dicts = {}
+ for key, value in state_dicts.items():
+ # Only process layer 61 (MTP layer) weights
+ if f"layer_{MTP_LAYER_ID}_" not in key:
+ res_dicts[key] = value
+ continue
+
+ # Skip lm_head and model.norm.weight - they're reused from main model
+ # and already converted by _convert_weights_on_demand
+ if "lm_head.weight" in key or "model.norm.weight" in key:
+ res_dicts[key] = value
+ continue
+
+ if "qkv_wa_weights" in key: # first op
+ weight_key = key
+ scale_key = key.replace("qkv_wa_weights", "qkv_wa_scales")
+ gamma_key = key.replace("qkv_wa_weights", "x_rmsnorm_gamma")
+ common_weights = RMSNormProjQAKVAKIWeightsConverter.tilert_to_common(
+ state_dicts[weight_key],
+ state_dicts[scale_key],
+ state_dicts[gamma_key],
+ )
+ conv_weights = (
+ RMSNormProjQAKVAKIWeightsConverter.common_to_tilert_native_bf16_warp_gemv(
+ *common_weights
+ )
+ )
+ res_dicts[key] = conv_weights[0]
+ elif "unproj_weights" in key: # unproj_o_allreduce op
+ weight_key = key
+ scale_key = key.replace("unproj_weights", "unproj_scales")
+ weights, scales = UnProjOAllreduceWeightsConverter.tilert_to_tilert_112sm_mma(
+ state_dicts[weight_key],
+ state_dicts[scale_key],
+ )
+ res_dicts[weight_key] = weights
+ res_dicts[scale_key] = scales
+ elif "unproj_scales" in key: # skip - already processed with unproj_weights
+ pass
+ elif "exp_upgate_weights" in key: # expert select up gate silu op
+ weight_key = key
+ scale_key = key.replace("exp_upgate_weights", "exp_upgate_scales")
+ weights_and_scales = ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm_mma(
+ state_dicts[weight_key], state_dicts[scale_key]
+ )
+ res_dicts[key] = weights_and_scales
+ elif "exp_down_weights" in key: # expert down allreduce op
+ weight_key = key
+ scale_key = key.replace("exp_down_weights", "exp_down_scales")
+ weights_swizzled, scales = DownAllreduceWeightsConverter.tilert_to_tilert_mma(
+ state_dicts[weight_key],
+ state_dicts[scale_key],
+ )
+ res_dicts[weight_key] = weights_swizzled
+ res_dicts[scale_key] = scales
else:
res_dicts[key] = value
@@ -149,11 +295,12 @@ def __init__(
self,
max_seq_len: int,
model_path: str = "",
- enable_fp8_ops: bool = False,
+ with_weight_conversion: bool = True,
+ with_mtp: bool = False,
) -> None:
super().__init__()
self.hidden_size = 7168
- self.seq_len = 1
+ self.forward_max_seq_len = 4 # max supported seq_len per forward
self.batch_size = 1
self.num_heads = 16
@@ -179,8 +326,8 @@ def __init__(
self.num_devices = 8
self.model_path = model_path
-
- self.enable_fp8_ops = enable_fp8_ops
+ self.with_weight_conversion = with_weight_conversion
+ self.with_mtp = with_mtp
self.multi_devices_results: list[DeviceResult | None] = [None] * torch.cuda.device_count()
@@ -196,263 +343,87 @@ def __init__(
self.placeholder = torch.zeros(1, 1, dtype=torch.int32, device="cpu")
- def golden_forward(self) -> None:
- raise NotImplementedError("golden_forward not implemented")
-
- def tilert_forward(self) -> None:
- raise NotImplementedError("tilert_forward not implemented")
-
- def to_tilert_weights(self) -> None:
- raise NotImplementedError("to_tilert_weights not implemented")
-
- def register_weights_and_scales(
- self, dim1: int, dim2: int, device: torch.device
- ) -> tuple[torch.Tensor, torch.Tensor]:
- block_size = 128
- weights_dims = (dim1, dim2)
- weights = torch.randn(weights_dims, dtype=torch.float16, device=device).to(
- torch.float8_e4m3fn
- )
- scales = torch.randn(
- (dim1, dim2 // block_size),
- dtype=torch.bfloat16,
- device=device,
- )
- return weights, scales
-
- def init_mla_params(self, device: torch.device, dev_attrs: dict) -> MlaParams:
- mat_attrs = {
- "device": device,
- "dtype": torch.float16,
- }
+ def _get_num_cache_layers(self) -> int:
+ """Return number of cache layers. Override in subclass for MTP."""
+ return self.NUM_LAYERS
- qkv_dim = self.q_dim + self.kv_dim + self.k_pe_dim + 128
- x_rmsnorm_gamma_shape = (self.hidden_size,)
- q_wb_shape = ((self.q_pe_lora_dim + self.q_nope_dim + 512) * self.num_heads, self.q_dim)
- wkv_b1_shape = (self.num_heads, self.q_pe_dim, self.v_head_dim)
- wkv_b2_shape = (self.num_heads, self.v_head_dim, self.kv_dim)
- wkv_b2_scales_shape = (self.num_heads, self.v_head_dim // 128, self.kv_dim // 128)
- unproj_w_shape = (self.hidden_size, self.num_heads * self.v_head_dim)
-
- x_rmsnorm_gamma = torch.randn(x_rmsnorm_gamma_shape, dtype=torch.float32, device=device)
- qkv_wa_weights, _ = self.register_weights_and_scales(qkv_dim, self.hidden_size, device)
- qkv_wa_scales = torch.randn((130, 64), dtype=torch.bfloat16, device=device)
- k_weights = torch.randn(128, dtype=torch.float32, device=device)
- k_bias = torch.randn(128, dtype=torch.float32, device=device)
- q_rmsnorm_gamma = torch.randn(self.q_dim, dtype=torch.float32, device=device)
- q_wb_weights, _ = self.register_weights_and_scales(*q_wb_shape, device)
- q_wb_scales = torch.randn((448, 12), dtype=torch.bfloat16, device=device)
- id_score_weights = torch.randn(64, self.hidden_size, dtype=torch.bfloat16, device=device)
- wkv_b1_weights = torch.randn(wkv_b1_shape, **mat_attrs).to(torch.float8_e4m3fn)
- wkv_b1_scales = torch.randn((16, 8, 1), dtype=torch.bfloat16, device=device)
- kv_rmsnorm_gamma = torch.randn(self.kv_dim, dtype=torch.float32, device=device)
- wkv_b2_weights = torch.randn(wkv_b2_shape, **mat_attrs).to(torch.float8_e4m3fn)
- wkv_b2_scales = torch.randn(wkv_b2_scales_shape, **dev_attrs)
- unproj_weights = torch.randn(unproj_w_shape, **mat_attrs).to(torch.float8_e4m3fn)
- unproj_scales = torch.randn((896, self.num_heads * self.v_head_dim // 128), **dev_attrs)
-
- return MlaParams(
- x_rmsnorm_gamma=x_rmsnorm_gamma,
- qkv_wa_weights=qkv_wa_weights,
- qkv_wa_scales=qkv_wa_scales,
- k_weights=k_weights,
- k_bias=k_bias,
- q_rmsnorm_gamma=q_rmsnorm_gamma,
- q_wb_weights=q_wb_weights,
- q_wb_scales=q_wb_scales,
- id_score_weights=id_score_weights,
- wkv_b1_weights=wkv_b1_weights,
- wkv_b1_scales=wkv_b1_scales,
- kv_rmsnorm_gamma=kv_rmsnorm_gamma,
- wkv_b2_weights=wkv_b2_weights,
- wkv_b2_scales=wkv_b2_scales,
- unproj_weights=unproj_weights,
- unproj_scales=unproj_scales,
+ def _prepare_money(
+ self,
+ params: list[torch.Tensor],
+ intermediates: list[torch.Tensor],
+ caches: list[torch.Tensor],
+ profile_logs: torch.Tensor,
+ ) -> None:
+ """Prepare money for show hands. Override in subclass for MTP."""
+ dsa_show_hands_prepare_money(
+ params, intermediates, caches, profile_logs, self.forward_max_seq_len
)
- def init_mlp_params(self, device: torch.device, dev_attrs: dict) -> MLPParams:
- mat_attrs = {
- "device": device,
- "dtype": torch.float16,
- }
+ def _show_hands(self, token_id: torch.Tensor) -> Any:
+ """Run show hands forward. Override in subclass for MTP."""
+ return dsa_show_hands(token_id.cpu())
- exp_upgate_w_shape = (9, self.exp_dims * 2, self.hidden_size)
- exp_upgate_s_shape = (9, self.exp_dims * 2 // 128, 64)
- exp_down_w_shape = (9, self.hidden_size, self.exp_dims)
- exp_down_s_shape = (9, 1024, self.exp_dims // 128)
-
- unproj_o_gamma = torch.randn(self.hidden_size, dtype=torch.float32, device=device)
- upgate_weights = torch.randn(exp_upgate_w_shape, **mat_attrs).to(torch.float8_e4m3fn)
- upgate_scales = torch.randn(exp_upgate_s_shape, **dev_attrs)
- down_weights = torch.randn(exp_down_w_shape, **mat_attrs).to(torch.float8_e4m3fn)
- down_scales = torch.randn(exp_down_s_shape, **dev_attrs)
-
- return MLPParams(
- unproj_o_gamma,
- upgate_weights,
- upgate_scales,
- down_weights,
- down_scales,
- )
+ def _reset_sequence_impl(self) -> None:
+ """Reset sequence implementation. Override in subclass for MTP."""
+ dsa_show_hands_reset(self.placeholder)
- def init_moe_params(self, device: torch.device, dev_attrs: dict) -> MoEParams:
- mat_attrs = {
- "device": device,
- "dtype": torch.float16,
- }
+ def _cleanup_impl(self) -> None:
+ """Cleanup implementation. Override in subclass for MTP."""
+ dsa_show_hands_go_home(self.placeholder)
- exp_upgate_w_shape = (self.n_routed_experts + 1, self.exp_dims * 2, self.hidden_size)
- exp_upgate_s_shape = (self.n_routed_experts + 1, self.exp_dims * 2 // 128, 64)
- exp_down_w_shape = (self.n_routed_experts + 1, self.hidden_size, self.exp_dims)
- exp_down_s_shape = (self.n_routed_experts + 1, 1024, self.exp_dims // 128)
-
- unproj_o_gamma = torch.randn(self.hidden_size, dtype=torch.float32, device=device)
- exp_proj_weights = torch.randn((self.n_routed_experts, self.hidden_size), **dev_attrs)
- exp_bias = torch.randn(self.n_routed_experts, dtype=torch.float32, device=device)
- exp_upgate_weights = torch.randn(exp_upgate_w_shape, **mat_attrs).to(torch.float8_e4m3fn)
- exp_upgate_scales = torch.randn(exp_upgate_s_shape, **dev_attrs)
- exp_down_weights = torch.randn(exp_down_w_shape, **mat_attrs).to(torch.float8_e4m3fn)
- exp_down_scales = torch.randn(exp_down_s_shape, **dev_attrs)
-
- return MoEParams(
- unproj_o_gamma,
- exp_proj_weights,
- exp_bias,
- exp_upgate_weights,
- exp_upgate_scales,
- exp_down_weights,
- exp_down_scales,
- )
+ def golden_forward(self) -> None:
+ raise NotImplementedError("golden_forward not implemented")
- def init_llm_head_params(self, device: torch.device, dev_attrs: dict) -> LLMHeadParams:
- del dev_attrs
- hidden_rms_gamma_shape = (self.hidden_size,)
- head_proj_weights_shape = (self.vocab_size, self.hidden_size)
+ def tilert_forward(self) -> None:
+ raise NotImplementedError("tilert_forward not implemented")
- hidden_rms_gamma = torch.randn(hidden_rms_gamma_shape, dtype=torch.float32, device=device)
- head_proj_weights = torch.randn(
- head_proj_weights_shape, dtype=torch.bfloat16, device=device
- )
- return LLMHeadParams(
- hidden_rms_gamma,
- head_proj_weights,
- )
+ def to_tilert_weights(self) -> BaseParams:
+ raise NotImplementedError("to_tilert_weights not implemented")
def get_mla_moe_layer_params_dict(
self, layer_id: int, device: torch.device, dev_attrs: dict
) -> dict[str, torch.Tensor]:
- mla_params_dict = self.init_mla_params(device, dev_attrs).to_dict(layer_id, device)
- moe_params_dict = self.init_moe_params(device, dev_attrs).to_dict(layer_id, device)
-
+ del dev_attrs
+ dsa_671b_model = Dsa671BModelInitializer(
+ torch.device(device),
+ with_weight_conversion=self.with_weight_conversion,
+ )
return {
- **mla_params_dict,
- **moe_params_dict,
+ **dsa_671b_model.init_mla_params().to_dict(layer_id, device),
+ **dsa_671b_model.init_moe_params().to_dict(layer_id, device),
}
def get_mla_mlp_layer_params_dict(
self, layer_id: int, device: torch.device, dev_attrs: dict
) -> dict[str, torch.Tensor]:
- mla_params_dict = self.init_mla_params(device, dev_attrs).to_dict(layer_id, device)
- mlp_params_dict = self.init_mlp_params(device, dev_attrs).to_dict(layer_id, device)
+ del dev_attrs
+ dsa_671b_model = Dsa671BModelInitializer(
+ torch.device(device),
+ with_weight_conversion=self.with_weight_conversion,
+ )
return {
- **mla_params_dict,
- **mlp_params_dict,
+ **dsa_671b_model.init_mla_params().to_dict(layer_id, device),
+ **dsa_671b_model.init_mlp_params().to_dict(layer_id, device),
}
def get_llm_head_layer_params_dict(
self, layer_id: int, device: torch.device, dev_attrs: dict
) -> dict[str, torch.Tensor]:
- return {**self.init_llm_head_params(device, dev_attrs).to_dict(layer_id, device)}
+ del dev_attrs
+ dsa_671b_model = Dsa671BModelInitializer(
+ torch.device(device),
+ with_weight_conversion=self.with_weight_conversion,
+ )
+ return {**dsa_671b_model.init_llm_head_params().to_dict(layer_id, device)}
def get_temp_vars(self, device: torch.device, dev_attrs: dict) -> TempVars:
- q = torch.zeros(self.batch_size, self.seq_len, self.q_dim, **dev_attrs)
- kv = torch.zeros(self.batch_size, self.seq_len, self.kv_dim, **dev_attrs)
- q_pe = torch.zeros(
- self.batch_size, self.seq_len, self.num_heads, self.q_pe_lora_dim, **dev_attrs
- )
- ki = torch.zeros(self.batch_size, self.seq_len, 128, **dev_attrs)
- q_nope_down = torch.zeros(
- self.batch_size, self.seq_len, self.num_heads, self.v_head_dim, **dev_attrs
- )
- q_nope = torch.zeros(
- self.batch_size, self.seq_len, self.num_heads, self.q_pe_dim, **dev_attrs
- )
- iq = torch.zeros(self.batch_size, self.seq_len, 64, 128, **dev_attrs)
- iq_rt = torch.zeros(self.batch_size, self.seq_len, 64, 128, **dev_attrs)
- idx_score = torch.zeros(self.batch_size, self.seq_len, 64, **dev_attrs)
- idx_logits = torch.zeros(
- self.batch_size, self.seq_len, self.max_seq_len, dtype=torch.float32, device=device
- )
- idx_sels = torch.zeros(self.batch_size, 2048, dtype=torch.int32, device=device)
- o = torch.zeros(self.batch_size, self.seq_len, self.num_heads, self.kv_dim, **dev_attrs)
- o_acc = torch.zeros(
- self.batch_size,
- self.num_heads,
- 128,
- self.kv_dim,
- dtype=torch.float32,
- device=device,
- )
- o_lse = torch.empty(self.batch_size, self.num_heads, dtype=torch.float32, device=device)
- o_lse_acc = torch.empty(
- self.batch_size, self.num_heads, 128, dtype=torch.float32, device=device
- )
- proj_o = torch.zeros(
- self.batch_size, self.seq_len, self.num_heads, self.v_head_dim, **dev_attrs
- )
- unproj_o = torch.zeros(self.batch_size, self.seq_len, self.hidden_size, **dev_attrs)
- scores = torch.zeros(
- self.batch_size, self.seq_len, self.n_routed_experts, dtype=torch.float32, device=device
- )
- x_mlp_in = torch.zeros(self.batch_size, self.seq_len, self.hidden_size, **dev_attrs)
- exp_up_gate = torch.zeros(
- self.batch_size, self.seq_len, self.n_activate_experts + 1, self.exp_dims, **dev_attrs
- )
- sel_probs = torch.zeros(
- self.batch_size,
- self.seq_len,
- self.n_activate_experts,
- dtype=torch.float32,
- device=device,
- )
- sel_indices = torch.zeros(
- self.batch_size, self.seq_len, self.n_activate_experts, dtype=torch.int32, device=device
- )
- exp_out = torch.zeros(self.batch_size, self.seq_len, self.hidden_size, **dev_attrs)
- x_rmsnorm = torch.zeros(self.batch_size, self.seq_len, self.hidden_size, **dev_attrs)
- logits_out = torch.zeros(
- self.batch_size, self.vocab_size, dtype=torch.float32, device=device
- )
- token_out = torch.zeros(self.batch_size, 1, dtype=torch.int32, device=device)
-
- return TempVars(
- q,
- kv,
- ki,
- q_nope_down,
- q_pe,
- iq,
- iq_rt,
- idx_score,
- idx_logits,
- idx_sels,
- q_nope,
- o,
- o_acc,
- o_lse,
- o_lse_acc,
- proj_o,
- unproj_o,
- scores,
- x_mlp_in,
- exp_up_gate,
- sel_probs,
- sel_indices,
- exp_out,
- x_rmsnorm,
- logits_out,
- token_out,
+ del dev_attrs
+ dsa_671b_model = Dsa671BModelInitializer(
+ torch.device(device),
+ with_weight_conversion=self.with_weight_conversion,
+ with_mtp=self.with_mtp,
)
+ return dsa_671b_model.acquire_temp_vars()
def get_cache_vars(self, device: torch.device, dev_attrs: dict) -> CacheVars:
del dev_attrs
@@ -490,16 +461,22 @@ def get_weight_files(self, weight_map: dict[str, str], device_id: int) -> list[s
return weight_files
- def load_embedding_weights(
- self, model_path: str, total_shards: int, device_id: int
- ) -> torch.Tensor:
+ def load_embedding_weights(self, model_path: str, device_id: int) -> torch.Tensor:
"""Load the embedding weights for the given device."""
- # the first shard is for embedding
- weight_prefix = "model.safetensors-00001-of"
- embed_weights_file = f"{weight_prefix}-{total_shards:05d}.safetensors"
+ # Look up the embedding file from index.json instead of hardcoding
+ index_file = os.path.join(model_path, "model.safetensors.index.json")
+ with open(index_file, encoding="utf-8") as f:
+ weights_index = json.load(f)
+ weight_map = weights_index["weight_map"]
+
+ embed_key = "model.embed_tokens.weight"
+ if embed_key not in weight_map:
+ raise ValueError(f"Embedding weight {embed_key} not found in index.json")
+
+ embed_weights_file = weight_map[embed_key]
embed_weights_file_path = os.path.join(model_path, embed_weights_file)
state_dict = load_file(embed_weights_file_path, device=f"cuda:{device_id}")
- return state_dict["model.embed_tokens.weight"]
+ return state_dict[embed_key]
def get_total_shards(self, model_path: str) -> int:
"""Get the total number of shards by counting safetensors files in the directory."""
@@ -510,7 +487,6 @@ def _init_weights(self, model_path: str | None) -> None:
"""Load the model weights from the given path or generate random weights."""
def _load_state_dicts(model_path: str, dev_attrs: dict) -> dict[str, torch.Tensor]:
- total_shards = self.get_total_shards(model_path)
device_id = dev_attrs["device"]
index_file = "model.safetensors.index.json"
with open(os.path.join(model_path, index_file), encoding="utf-8") as f:
@@ -525,7 +501,7 @@ def _load_state_dicts(model_path: str, dev_attrs: dict) -> dict[str, torch.Tenso
)
logger.info(f"Loaded weights from {weight_file} for device {device_id}")
state_dicts.update(state_dict)
- embed_weights = self.load_embedding_weights(model_path, total_shards, device_id)
+ embed_weights = self.load_embedding_weights(model_path, device_id)
state_dicts["model.embed_tokens.weight"] = embed_weights
return state_dicts
@@ -541,12 +517,14 @@ def _gen_state_dicts_with_random_weights(dev_attrs: dict) -> dict[str, torch.Ten
self.get_mla_moe_layer_params_dict(layer_id, device_id, dev_attrs)
)
state_dicts.update(self.get_llm_head_layer_params_dict(61, device_id, dev_attrs))
- state_dicts["model.embed_tokens.weight"] = torch.randn(
- self.vocab_size_full, self.hidden_size, **dev_attrs
+ dsa_671b_model = Dsa671BModelInitializer(
+ torch.device(device_id),
+ with_weight_conversion=self.with_weight_conversion,
)
+ state_dicts.update(dsa_671b_model.init_embedding_params().to_dict(device_id))
return state_dicts
- def __load_weights(device_id: int, model_path: str) -> None:
+ def __load_weights(device_id: int, model_path: str | None) -> None:
intermediates: list[torch.Tensor] = []
caches: list[torch.Tensor] = []
params: list[torch.Tensor] = []
@@ -562,42 +540,32 @@ def __load_weights(device_id: int, model_path: str) -> None:
device_id, dev_attrs
).generate_params_with_continuous_storage(device_id)
)
- for _ in range(self.NUM_LAYERS):
+ num_cache_layers = self._get_num_cache_layers()
+ for _ in range(num_cache_layers):
caches.extend(self.get_cache_vars(device_id, dev_attrs).get_params())
logger.info(f"Created intermediates and caches for device {device_id}")
- if model_path and os.path.exists(model_path):
+ load_from_path = bool(model_path and os.path.exists(model_path))
+ if load_from_path:
+ assert model_path is not None # Type narrowing for mypy
state_dicts = _load_state_dicts(model_path, dev_attrs)
else:
state_dicts = _gen_state_dicts_with_random_weights(dev_attrs)
- # Do necessary weight conversions here
- state_dicts = _convert_weights_on_demand(state_dicts)
+ # Do necessary weight conversions only for loaded weights.
+ # Skip MTP layer (layer 61) conversion here if with_mtp is True,
+ # as it will be handled separately.
+ if load_from_path:
+ state_dicts = _convert_weights_on_demand(
+ state_dicts, skip_mtp_layer=self.with_mtp
+ )
for layer_id in range(self.NUM_DENSE_LAYERS):
+ # Each layer has its dedicated cache
for param in DenseLayerParamsKeys:
key_name = f"layer_{layer_id}_{param}_dev_{device_id}"
if key_name not in state_dicts:
raise ValueError(f"Weight {key_name} not found")
params.append(state_dicts[key_name])
- unproj_o_allreduce_fp8_params = torch.empty(0, device=device_id)
- if self.enable_fp8_ops:
- unproj_o_allreduce_fp8_params = gen_unproj_o_allreduce_fp8_params(
- state_dicts[f"layer_{layer_id}_unproj_weights_dev_{device_id}"],
- state_dicts[f"layer_{layer_id}_unproj_scales_dev_{device_id}"],
- )
- params.extend(MlaFp8Params(unproj_o_allreduce_fp8_params).get_params())
- upgate_all_reduce_fp8_params = torch.empty(0, device=device_id)
- down_all_reduce_fp8_params = torch.empty(0, device=device_id)
- if self.enable_fp8_ops:
- down_all_reduce_fp8_params = gen_down_allreduce_fp8_params(
- state_dicts[f"layer_{layer_id}_down_weights_dev_{device_id}"],
- state_dicts[f"layer_{layer_id}_down_scales_dev_{device_id}"],
- )
- params.extend(
- MLPFp8Params(
- upgate_all_reduce_fp8_params, down_all_reduce_fp8_params
- ).get_params()
- )
for layer_id in range(3, 3 + self.NUM_MOE_LAYERS):
# Each layer has its dedicated cache
@@ -606,26 +574,8 @@ def __load_weights(device_id: int, model_path: str) -> None:
if key_name not in state_dicts:
raise ValueError(f"Weight {key_name} not found")
params.append(state_dicts[key_name])
- unproj_o_allreduce_fp8_params = torch.empty(0, device=device_id)
- if self.enable_fp8_ops:
- unproj_o_allreduce_fp8_params = gen_unproj_o_allreduce_fp8_params(
- state_dicts[f"layer_{layer_id}_unproj_weights_dev_{device_id}"],
- state_dicts[f"layer_{layer_id}_unproj_scales_dev_{device_id}"],
- )
- params.extend(MlaFp8Params(unproj_o_allreduce_fp8_params).get_params())
- expert_up_gate_fp8_params = torch.empty(0, device=device_id)
- expert_down_all_reduce_fp8_params = torch.empty(0, device=device_id)
- if self.enable_fp8_ops:
- expert_down_all_reduce_fp8_params = gen_expert_down_allreduce_fp8_params(
- state_dicts[f"layer_{layer_id}_exp_down_weights_dev_{device_id}"],
- state_dicts[f"layer_{layer_id}_exp_down_scales_dev_{device_id}"],
- )
- params.extend(
- MoEFp8Params(
- expert_up_gate_fp8_params, expert_down_all_reduce_fp8_params
- ).get_params()
- )
+ # heads
head = f"layer_61_lm_head.weight_dev_{device_id}"
head_norm = f"layer_61_model.norm.weight_dev_{device_id}"
if head not in state_dicts:
@@ -635,12 +585,77 @@ def __load_weights(device_id: int, model_path: str) -> None:
params.append(state_dicts[head_norm])
params.append(state_dicts[head])
+ # embed_weights = self.load_embedding_weights(total_shards, device_id)
params.append(state_dicts["model.embed_tokens.weight"])
# RoPE frequencies
freqs_cis = self._gen_freqs_cis()
params.extend([freqs_cis.to(device_id)])
+ # Add MTP-specific params when with_mtp is True
+ if self.with_mtp:
+ if load_from_path:
+ # Load real MTP weights from state_dicts
+ # Convert MTP-specific weights (layer 61)
+ state_dicts = _convert_mtp_weights_on_demand(state_dicts)
+
+ # MTP params order (matching C++ register_mtp.cu):
+ # 1. LlmPreprocessModule: 2 (embedding + freqs_cis)
+ # 2. MtpPreProcessLayer: 3 (embedding_rmsnorm_gamma,
+ # hidden_rmsnorm_gamma, eh_proj_weights)
+ # 3. MoeBlock: 23 (MLA 16 + MOE 7)
+ # 4. LlmHeadModule: 2 (hidden_rms_gamma, head_proj_weights)
+
+ # 1. Embedding params (2)
+ params.append(state_dicts["model.embed_tokens.weight"])
+ mtp_freqs_cis = self._gen_freqs_cis()
+ params.append(mtp_freqs_cis.to(device_id))
+
+ # 2. MTP preprocess params (3)
+ for key in MTPPreprocessParamsKeys:
+ full_key = f"layer_{MTP_LAYER_ID}_{key}_dev_{device_id}"
+ if full_key not in state_dicts:
+ raise ValueError(f"MTP weight {full_key} not found")
+ params.append(state_dicts[full_key])
+
+ # 3. MLA params (16) + MOE params (7) = MoeBlock (23)
+ for key in MTPMlaParamsKeys:
+ full_key = f"layer_{MTP_LAYER_ID}_{key}_dev_{device_id}"
+ if full_key not in state_dicts:
+ raise ValueError(f"MTP weight {full_key} not found")
+ params.append(state_dicts[full_key])
+
+ for key in MTPMoeParamsKeys:
+ full_key = f"layer_{MTP_LAYER_ID}_{key}_dev_{device_id}"
+ if full_key not in state_dicts:
+ raise ValueError(f"MTP weight {full_key} not found")
+ params.append(state_dicts[full_key])
+
+ # 4. LLM head params (2)
+ mtp_head_norm = f"layer_{MTP_LAYER_ID}_model.norm.weight_dev_{device_id}"
+ mtp_head = f"layer_{MTP_LAYER_ID}_lm_head.weight_dev_{device_id}"
+ if mtp_head_norm not in state_dicts:
+ raise ValueError(f"MTP weight {mtp_head_norm} not found")
+ if mtp_head not in state_dicts:
+ raise ValueError(f"MTP weight {mtp_head} not found")
+ params.append(state_dicts[mtp_head_norm])
+ params.append(state_dicts[mtp_head])
+
+ logger.info(f"Loaded real MTP weights for device {device_id}")
+ else:
+ # Use random weights for MTP
+ dsa_671b_model = Dsa671BModelInitializer(
+ torch.device(device_id),
+ with_weight_conversion=self.with_weight_conversion,
+ with_mtp=True,
+ )
+ # MTP needs: embedding, mtp_preprocess, mla, moe, llm_head params
+ params.extend(dsa_671b_model.init_embedding_params().get_params())
+ params.extend(dsa_671b_model.init_mtp_preprocess_params().get_params())
+ params.extend(dsa_671b_model.init_mla_params().get_params())
+ params.extend(dsa_671b_model.init_moe_params().get_params())
+ params.extend(dsa_671b_model.init_llm_head_params().get_params())
+
profile_logs = get_profile_log_tensor(device=device_id, num_max_insts=65536)
result = (intermediates, caches, params, profile_logs)
self.multi_devices_results[device_id] = result
@@ -654,24 +669,29 @@ def __load_weights(device_id: int, model_path: str) -> None:
logger.info(f"Completed loading weights for device {device_id} in {time_str}")
threads = []
+ exceptions: list[Exception | None] = [None] * self.num_devices
for device_id in range(self.num_devices):
- thread = threading.Thread(target=__load_weights, args=(device_id, model_path))
+
+ def _runner(dev_id: int) -> None:
+ try:
+ __load_weights(dev_id, model_path)
+ except Exception as exc: # pragma: no cover - surfaced after join
+ exceptions[dev_id] = exc
+
+ thread = threading.Thread(target=_runner, args=(device_id,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
+ for device_id, exc in enumerate(exceptions):
+ if exc is not None:
+ raise RuntimeError(f"Failed to initialize device {device_id}: {exc}") from exc
+ # Prepare money for all devices
for device_id in range(self.num_devices):
with torch.cuda.device(device_id):
intermediates, caches, params, profile_logs = self._get_device_result(device_id)
- dsa_show_hands_prepare_money(
- True, # enable fused op
- self.enable_fp8_ops,
- params,
- intermediates,
- caches,
- profile_logs,
- )
+ self._prepare_money(params, intermediates, caches, profile_logs)
def from_pretrained(self, model_path: str) -> None:
"""Load the model weights from the given path."""
@@ -687,14 +707,14 @@ def forward(
self,
token_id: torch.Tensor,
) -> list[DeviceResult]:
- dsa_show_hands(token_id.cpu())
+ self._show_hands(token_id)
return [self._get_device_result(device_id) for device_id in range(self.num_devices)]
def reset_sequence(self) -> None:
- dsa_show_hands_reset(self.placeholder)
+ self._reset_sequence_impl()
def cleanup(self) -> None:
- dsa_show_hands_go_home(self.placeholder)
+ self._cleanup_impl()
def __del__(self) -> None:
try:
@@ -715,7 +735,7 @@ def __init__(
max_new_tokens: int = 100,
temperature: float = 1.0,
model_weights_dir: str = "",
- enable_fp8_ops: bool = False,
+ with_mtp: bool = False,
):
"""Initialize the ShowHandsGenerator.
@@ -723,12 +743,14 @@ def __init__(
max_new_tokens: Maximum number of new tokens to generate. Defaults to 100.
temperature: Temperature for sampling. Defaults to 1.0.
model_weights_dir: Path of the model weights directory.
+ with_mtp: Whether to use MTP (Multi-Token Prediction) for speculative decoding.
"""
torch.set_num_threads(64)
self.model_weights_dir = model_weights_dir
self.max_new_tokens = max_new_tokens
self.temperature = temperature
+ self.with_mtp = with_mtp
self.config = ModelArgsV3_2()
self.tokenizer = AutoTokenizer.from_pretrained(self.model_weights_dir)
@@ -737,16 +759,27 @@ def __init__(
self.default_device = torch.device("cuda:0")
- self.decode_layer = ShowHandsDSALayer(
- max_seq_len=self.config.max_seq_len,
- model_path=self.model_weights_dir,
- enable_fp8_ops=enable_fp8_ops,
- )
+ if with_mtp:
+ from tilert.models.deepseek_v3_2.dsa_mtp_e2e_show_hands import ShowHandsDsaMtpE2eLayer
+
+ self.decode_layer = ShowHandsDsaMtpE2eLayer(
+ max_seq_len=self.config.max_seq_len,
+ )
+ self.mtp_seq_len = self.decode_layer.MTP_SEQ_LEN # 4
+ else:
+ self.decode_layer = ShowHandsDSALayer(
+ max_seq_len=self.config.max_seq_len,
+ model_path=self.model_weights_dir,
+ )
def init(self) -> None:
"""Initialize the ShowHandsGenerator."""
tilert_init()
+ def cleanup(self) -> None:
+ """Cleanup the ShowHandsGenerator."""
+ self.decode_layer.cleanup()
+
def init_random_weights(self) -> None:
"""Random initialize the weights."""
self.decode_layer.init_random_weights()
@@ -756,8 +789,20 @@ def from_pretrained(self) -> None:
self.decode_layer.from_pretrained(self.model_weights_dir)
@torch.inference_mode()
- def generate(self, prompt: str) -> str:
- """Main function to load the model and perform single sequence generation."""
+ def generate(self, prompt: str, print_log: bool = True) -> tuple[str, list[float], list[int]]:
+ """Main function to load the model and perform single sequence generation.
+
+ Returns:
+ Tuple of (result_text, time_list, accepted_counts).
+ accepted_counts is empty for non-MTP mode.
+ """
+ if self.with_mtp:
+ return self._generate_with_mtp(prompt, print_log)
+ result, time_list = self._generate_without_mtp(prompt, print_log)
+ return result, time_list, [] # Empty accepted_counts for non-MTP
+
+ def _generate_without_mtp(self, prompt: str, print_log: bool = True) -> tuple[str, list[float]]:
+ """Standard generation without MTP."""
prompt_tokens = self.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}], add_generation_prompt=True
)
@@ -788,7 +833,7 @@ def generate(self, prompt: str) -> str:
intermediates, *_ = multi_devices_results[0]
intermediates_mapper = IntermediateMapper(list(intermediates[-TempVars.num_params() :]))
- next_token = intermediates_mapper.token_out[0]
+ next_token = intermediates_mapper.token_out[0][0] # only the first token
# replace the next token with the prompt token if the prompt mask is True
next_token = torch.where(
@@ -801,17 +846,19 @@ def generate(self, prompt: str) -> str:
decoded_tokens = self.tokenizer.decode(
[next_token.item()], skip_special_tokens=True
)
- print(decoded_tokens, end="", flush=True)
+ if print_log:
+ print(decoded_tokens, end="", flush=True)
if finished.all():
break
- print("\n")
- logger.info(f"--Number of tokens generated: {len(time_list)}")
+ if print_log:
+ print("\n")
+ logger.info(f"--Number of tokens generated: {len(time_list)}")
- # skip the first several samples to avoid the warmup effect
- stats_time(time_list[5:], "==== Performance ====")
- print("\n")
+ # skip the first several samples to avoid the warmup effect
+ stats_time(time_list[5:], "==== Performance ====")
+ print("\n")
# Reset sequence after generation, i.e. reset the cur_pos to 0 internally
self.decode_layer.reset_sequence()
@@ -825,4 +872,169 @@ def generate(self, prompt: str) -> str:
decoded_tokens = self.tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
- return f"{decoded_tokens[0]}\n" if decoded_tokens else ""
+ return f"{decoded_tokens[0]}\n" if decoded_tokens else "", time_list
+
+ def _generate_with_mtp(
+ self, prompt: str, print_log: bool = True
+ ) -> tuple[str, list[float], list[int]]:
+ """Generation with MTP (Multi-Token Prediction) speculative decoding."""
+ prompt_tokens = self.tokenizer.apply_chat_template(
+ [{"role": "user", "content": prompt}], add_generation_prompt=True
+ )
+
+ max_seq_len = self.config.max_seq_len
+ prompt_len = len(prompt_tokens)
+ total_len = min(max_seq_len, self.max_new_tokens + prompt_len)
+
+ # Output tokens buffer
+ tokens = torch.full(
+ (self.batch_size, total_len), -1, dtype=torch.long, device=self.default_device
+ )
+ tokens[0, :prompt_len] = torch.tensor(
+ prompt_tokens, dtype=torch.long, device=self.default_device
+ )
+
+ prefill_time_list = []
+ decode_time_list = []
+ decode_accepted_counts = [] # Only track decode phase for statistics
+ cur_pos = 0 # Current position in the output sequence
+
+ # Prefill phase: process prompt tokens in chunks
+ # Process prompt tokens in chunks of mtp_seq_len (with overlap)
+ while cur_pos < prompt_len - 1:
+ draft_end = min(cur_pos + self.mtp_seq_len, prompt_len)
+ draft_tokens = tokens[0, cur_pos:draft_end].clone()
+ actual_token_count = draft_tokens.shape[0]
+
+ # Pad if needed (use last token for padding)
+ if actual_token_count < self.mtp_seq_len:
+ pad_token = draft_tokens[-1].item()
+ padding = torch.full(
+ (self.mtp_seq_len - actual_token_count,),
+ pad_token,
+ dtype=torch.long,
+ device=self.default_device,
+ )
+ draft_tokens = torch.cat([draft_tokens, padding])
+
+ draft_tokens = draft_tokens.reshape(1, self.mtp_seq_len).to(torch.int32)
+
+ # Tell GPU how many tokens are valid (for cur_pos advancement)
+ self.decode_layer.set_prefill_valid_tokens(actual_token_count)
+
+ start_time = time.time()
+ self.decode_layer.forward(draft_tokens)
+ end_time = time.time()
+ prefill_time_list.append(end_time - start_time)
+
+ # Advance cur_pos by (actual_token_count - 1) to maintain overlap
+ # This ensures cur_pos ends at prompt_len - 1 after all chunks
+ cur_pos += actual_token_count - 1
+
+ # Decode phase: speculative decoding
+ # Set prefill_valid_tokens to 0 to switch to decode mode
+ self.decode_layer.set_prefill_valid_tokens(0)
+
+ finished = False
+ while cur_pos < total_len - 1 and not finished:
+ # Get next_draft_tokens from previous iteration
+ # (or use last prompt tokens for first decode)
+ if cur_pos == prompt_len - 1:
+ # First decode iteration: use last prompt token repeated as placeholder drafts
+ # We can't use [t6, t7, t8, t9] because that would apply wrong RoPE positions
+ # (cur_pos=9 means positions 9,10,11,12, but t6 should be at position 6)
+ last_token = tokens[0, prompt_len - 1].item()
+ draft_tokens = torch.full(
+ (self.mtp_seq_len,),
+ last_token,
+ dtype=torch.long,
+ device=self.default_device,
+ )
+ draft_tokens = draft_tokens.reshape(1, self.mtp_seq_len).to(torch.int32)
+ else:
+ # Use next_draft_tokens from previous iteration
+ draft_tokens = self.decode_layer.get_next_draft_tokens(0).reshape(
+ 1, self.mtp_seq_len
+ )
+
+ start_time = time.time()
+ self.decode_layer.forward(draft_tokens)
+ end_time = time.time()
+ decode_time_list.append(end_time - start_time)
+
+ num_accepted = self.decode_layer.get_num_accepted(0)
+ # Use predicted_tokens for output (not next_draft_tokens which is for next iteration)
+ predicted_tokens = self.decode_layer.get_predicted_tokens(0).flatten()
+ decode_accepted_counts.append(num_accepted)
+
+ # Add accepted tokens to output
+ num_output_tokens = num_accepted
+ for i in range(num_output_tokens):
+ if cur_pos + 1 + i >= total_len:
+ break
+ new_token = int(predicted_tokens[i].item())
+ tokens[0, cur_pos + 1 + i] = new_token
+
+ # Print generated token
+ if cur_pos + 1 + i >= prompt_len and print_log:
+ decoded_text = self.tokenizer.decode([new_token], skip_special_tokens=True)
+ print(decoded_text, end="", flush=True)
+
+ # Check for EOS
+ if new_token == self.eos_id:
+ finished = True
+ break
+
+ cur_pos += num_accepted
+
+ if print_log:
+ print("\n")
+ total_tokens = sum(decode_accepted_counts)
+ logger.info(f"--Number of forward calls (decode): {len(decode_accepted_counts)}")
+ logger.info(f"--Total tokens generated: {total_tokens}")
+ if len(decode_accepted_counts) > 0:
+ avg_accepted = sum(decode_accepted_counts) / len(decode_accepted_counts)
+ min_accepted = min(decode_accepted_counts)
+ max_accepted = max(decode_accepted_counts)
+ logger.info(
+ f"--Accepted tokens per call: mean={avg_accepted:.2f}, "
+ f"min={min_accepted}, max={max_accepted}"
+ )
+
+ # Calculate correct TPS accounting for MTP's multiple tokens per call
+ if len(decode_time_list) > 5:
+ total_decode_time = sum(decode_time_list[5:]) # skip warmup
+ tokens_after_warmup = (
+ sum(decode_accepted_counts[5:])
+ if len(decode_accepted_counts) > 5
+ else total_tokens
+ )
+ effective_tps = (
+ tokens_after_warmup / total_decode_time if total_decode_time > 0 else 0
+ )
+ avg_time_ms = total_decode_time / len(decode_time_list[5:]) * 1000
+ logger.info(f"--Avg forward time: {avg_time_ms:.2f}ms")
+ logger.info(f"--Effective TPS (with MTP): {effective_tps:.2f} tokens/s")
+
+ print("\n")
+
+ # Reset sequence after generation
+ self.decode_layer.reset_sequence()
+
+ # Extract completion tokens
+ completion_tokens = []
+ for _, toks in enumerate(tokens.tolist()):
+ toks = toks[prompt_len : prompt_len + self.max_new_tokens]
+ # Remove -1 padding and tokens after EOS
+ toks = [t for t in toks if t != -1]
+ if self.eos_id in toks:
+ toks = toks[: toks.index(self.eos_id)]
+ completion_tokens.append(toks)
+
+ decoded_tokens = self.tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
+
+ return (
+ f"{decoded_tokens[0]}\n" if decoded_tokens else "",
+ decode_time_list,
+ decode_accepted_counts,
+ )
diff --git a/python/models/deepseek_v3_2/params.py b/python/models/deepseek_v3_2/params.py
index 30c6b8d..a3914f8 100644
--- a/python/models/deepseek_v3_2/params.py
+++ b/python/models/deepseek_v3_2/params.py
@@ -2,7 +2,8 @@
import torch
-from tilert.models.utils import SwizzleMode, gen_tensor_swizzle_map_1d
+from tilert.models.deepseek_v3_2.model_args import ModelArgs as ModelArgsV3_2
+from tilert.models.utils import SwizzleMode, gen_tensor_swizzle_map_1d, precompute_freqs_cis
__all__ = [
"IntermediateMapper",
@@ -196,6 +197,33 @@ def num_params() -> int:
raise NotImplementedError("Subclasses must implement this method")
+class PlaceHolderParams(BaseParams):
+ def __init__(self) -> None:
+ super().__init__()
+
+ @staticmethod
+ def num_params() -> int:
+ return 0
+
+
+class EmbeddingParams(BaseParams):
+ def __init__(
+ self,
+ embedding: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ ):
+ super().__init__()
+ self.embedding = self.register_params(embedding)
+ self.freqs_cis = self.register_params(freqs_cis)
+
+ @staticmethod
+ def num_params() -> int:
+ return 1
+
+ def to_dict(self, device_id: int) -> dict[str, torch.Tensor]:
+ return {"model.embed_tokens.weight": self.embedding.to(device_id)}
+
+
class MlaParams(BaseParams):
def __init__(
self,
@@ -337,47 +365,59 @@ def to_dict(self, layer_id: int, device_id: int) -> dict[str, torch.Tensor]:
}
-class MlaFp8Params(BaseParams):
- def __init__(
- self,
- unproj_o_weights_and_scales: torch.Tensor,
- ) -> None:
- super().__init__()
- self.unproj_o_weights_and_scales = self.register_params(unproj_o_weights_and_scales)
-
- @staticmethod
- def num_params() -> int:
- return 1
-
+class LLMHeadParams(BaseParams):
+ """LLM Head Parameters"""
-class MLPFp8Params(BaseParams):
def __init__(
self,
- upgate_weights_and_scales: torch.Tensor,
- down_weights_and_scales: torch.Tensor,
+ hidden_rms_gamma: torch.Tensor,
+ head_proj_weights: torch.Tensor,
) -> None:
super().__init__()
- self.upgate_weights_and_scales = self.register_params(upgate_weights_and_scales)
- self.down_weights_and_scales = self.register_params(down_weights_and_scales)
+ self.hidden_rms_gamma = self.register_params(hidden_rms_gamma)
+ self.head_proj_weights = self.register_params(head_proj_weights)
@staticmethod
def num_params() -> int:
return 2
+ def to_dict(self, layer_id: int, device_id: int) -> dict[str, torch.Tensor]:
+ return {
+ f"layer_{layer_id}_model.norm.weight_dev_{device_id}": self.hidden_rms_gamma.to(
+ device_id
+ ),
+ f"layer_{layer_id}_lm_head.weight_dev_{device_id}": self.head_proj_weights.to(
+ device_id
+ ),
+ }
+
-class MoEFp8Params(BaseParams):
+class MTPPreprocessParams(BaseParams):
def __init__(
self,
- exp_upgate_weights_and_scales: torch.Tensor,
- exp_down_weights_and_scales: torch.Tensor,
+ embedding_rmsnorm_gamma: torch.Tensor,
+ hidden_rmsnorm_gamma: torch.Tensor,
+ eh_proj_weights: torch.Tensor,
) -> None:
super().__init__()
- self.exp_upgate_weights_and_scales = self.register_params(exp_upgate_weights_and_scales)
- self.exp_down_weights_and_scales = self.register_params(exp_down_weights_and_scales)
+ self.embedding_rmsnorm_gamma = self.register_params(embedding_rmsnorm_gamma)
+ self.hidden_rmsnorm_gamma = self.register_params(hidden_rmsnorm_gamma)
+ self.eh_proj_weights = self.register_params(eh_proj_weights)
@staticmethod
def num_params() -> int:
- return 2
+ return 3
+
+ def to_dict(self, layer_id: int, device_id: int) -> dict[str, torch.Tensor]:
+ return {
+ f"layer_{layer_id}_embedding_rmsnorm_gamma_dev_{device_id}": (
+ self.embedding_rmsnorm_gamma.to(device_id)
+ ),
+ f"layer_{layer_id}_hidden_rmsnorm_gamma_dev_{device_id}": self.hidden_rmsnorm_gamma.to(
+ device_id
+ ),
+ f"layer_{layer_id}_eh_proj_weights_dev_{device_id}": self.eh_proj_weights.to(device_id),
+ }
class TempVars(BaseParams):
@@ -409,6 +449,19 @@ def __init__(
x_rmsnorm: torch.Tensor,
logits_out: torch.Tensor,
token_out: torch.Tensor,
+ embedding_rmsnorm: torch.Tensor,
+ hidden_rmsnorm: torch.Tensor,
+ eh_proj: torch.Tensor,
+ x_tensor: torch.Tensor,
+ rope_freqs: torch.Tensor,
+ cur_pos: torch.Tensor,
+ token_id: torch.Tensor,
+ last_hidden_states: torch.Tensor,
+ draft_tokens: torch.Tensor,
+ predicted_tokens: torch.Tensor,
+ predicted_hidden: torch.Tensor,
+ accepted_tokens: torch.Tensor,
+ next_draft_tokens: torch.Tensor,
) -> None:
super().__init__()
self.q = self.register_params(q)
@@ -437,10 +490,23 @@ def __init__(
self.x_rmsnorm = self.register_params(x_rmsnorm)
self.logits_out = self.register_params(logits_out)
self.token_out = self.register_params(token_out)
+ self.embedding_rmsnorm = self.register_params(embedding_rmsnorm)
+ self.hidden_rmsnorm = self.register_params(hidden_rmsnorm)
+ self.eh_proj = self.register_params(eh_proj)
+ self.x_tensor = self.register_params(x_tensor)
+ self.rope_freqs = self.register_params(rope_freqs)
+ self.cur_pos = self.register_params(cur_pos)
+ self.token_id = self.register_params(token_id)
+ self.last_hidden_states = self.register_params(last_hidden_states)
+ self.draft_tokens = self.register_params(draft_tokens)
+ self.predicted_tokens = self.register_params(predicted_tokens)
+ self.predicted_hidden = self.register_params(predicted_hidden)
+ self.accepted_tokens = self.register_params(accepted_tokens)
+ self.next_draft_tokens = self.register_params(next_draft_tokens)
@staticmethod
def num_params() -> int:
- return 26
+ return 39
def tot_size_in_bytes_aligned(self, aligned_size: int) -> int:
tot_size: int = 0
@@ -482,28 +548,395 @@ def num_params() -> int:
return 3
-class LLMHeadParams(BaseParams):
- """LLM Head Parameters"""
+class Dsa671BModelInitializer:
+ """DSA with MTP e2e model for DeepSeek v3.2"""
+
+ # TODO: These parameters should be carefully checked
+ BATCH_SIZE = 1
+ MAX_SEQ_LEN = 4
+ NUM_HEADS = 16
+ NUM_KI_HEADS = 64
+
+ MAX_OPS = 2048
+ MAX_SEL_TOKENS = 2048
+ MAX_CTX_LEN = 16384
+ NUM_DENSE_LAYERS = 3
+ NUM_MOE_LAYERS = 58
+ NUM_LAYERS = NUM_DENSE_LAYERS + NUM_MOE_LAYERS
+
+ HIDDEN_SIZE = 7168
+ PE_LORA_DIM = 64
+ Q_NOPE_DIM = 128
+
+ Q_DIM = 1536
+ KV_CACHE_DIM = 512
+ PE_CACHE_DIM = 64
+ KI_CACHE_DIM = 128
+ Q_PE_DIM = 512
+ V_HEAD_DIM = 128
+
+ N_ROUTED_EXPERTS = 256
+ N_ACTIVATE_EXPERTS = 8
+ N_TOTAL_EXPERTS = N_ACTIVATE_EXPERTS + 1
+ EXP_DIMS = 256
+
+ FULL_VOCAB_SIZE = 129280
+ VOCAB_SIZE = FULL_VOCAB_SIZE // 8 # 16160
def __init__(
self,
- hidden_rms_gamma: torch.Tensor,
- head_proj_weights: torch.Tensor,
+ device: torch.device,
+ max_seq_len: int | None = None,
+ max_ctx_len: int | None = None,
+ with_weight_conversion: bool = True,
+ with_mtp: bool = False,
) -> None:
super().__init__()
- self.hidden_rms_gamma = self.register_params(hidden_rms_gamma)
- self.head_proj_weights = self.register_params(head_proj_weights)
- @staticmethod
- def num_params() -> int:
- return 2
+ self.device = device
+ self.max_seq_len = max_seq_len if max_seq_len is not None else self.MAX_SEQ_LEN
+ self.max_ctx_len = max_ctx_len if max_ctx_len is not None else self.MAX_CTX_LEN
+ self.with_weight_conversion = with_weight_conversion
+ self.with_mtp = with_mtp
+
+ self.bf16_desc = {"dtype": torch.bfloat16, "device": device}
+ self.fp16_desc = {"dtype": torch.float16, "device": device}
+ self.fp32_desc = {"dtype": torch.float32, "device": device}
+ self.uint64_desc = {"dtype": torch.uint64, "device": device}
+ self.int32_desc = {"dtype": torch.int32, "device": device}
+ self.uint8_desc = {"dtype": torch.uint8, "device": device}
+
+ self.mtp_params_sidx = 0
+
+ def register_weights_and_scales(
+ self, dim1: int, dim2: int
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ block_size = 128
+ weights_dims = (dim1, dim2)
+ weights = torch.randn(weights_dims, **self.bf16_desc).to(torch.float8_e4m3fn)
+ scales = torch.randn((dim1, dim2 // block_size), **self.bf16_desc)
+ return weights, scales
+
+ def init_llm_head_params(self) -> LLMHeadParams:
+ from tilert.models.preprocess.weight_utils import RMSNormHeadProjWeightsConverter
+
+ hidden_rms_gamma_shape = (self.HIDDEN_SIZE,)
+ head_proj_weights_shape = (self.VOCAB_SIZE, self.HIDDEN_SIZE)
+
+ hidden_rms_gamma = torch.randn(hidden_rms_gamma_shape, **self.fp32_desc)
+ head_proj_weights = torch.randn(head_proj_weights_shape, **self.bf16_desc)
+
+ if self.with_weight_conversion:
+ # Apply weight conversion for LLM head
+ head_proj_weights = (
+ RMSNormHeadProjWeightsConverter.tilert_to_tilert_native_bf16_warp_gemv(
+ head_proj_weights
+ )
+ )
- def to_dict(self, layer_id: int, device_id: int) -> dict[str, torch.Tensor]:
- return {
- f"layer_{layer_id}_model.norm.weight_dev_{device_id}": self.hidden_rms_gamma.to(
- device_id
- ),
- f"layer_{layer_id}_lm_head.weight_dev_{device_id}": self.head_proj_weights.to(
- device_id
- ),
- }
+ return LLMHeadParams(
+ hidden_rms_gamma,
+ head_proj_weights,
+ )
+
+ def init_embedding_params(self) -> EmbeddingParams:
+ embedding = torch.randn(self.FULL_VOCAB_SIZE, self.HIDDEN_SIZE, **self.bf16_desc)
+ freqs_cis = precompute_freqs_cis(ModelArgsV3_2())
+ freqs_cis = torch.view_as_real(freqs_cis).reshape(freqs_cis.shape[0], -1)
+ return EmbeddingParams(embedding, freqs_cis.to(self.device))
+
+ def init_mla_params(self) -> MlaParams:
+ from tilert.models.preprocess.weight_utils import (
+ RMSNormProjQAKVAKIWeightsConverter,
+ )
+
+ qkv_dim = self.Q_DIM + self.KV_CACHE_DIM + self.PE_LORA_DIM + 128
+ x_rmsnorm_gamma_shape = (self.HIDDEN_SIZE,)
+ q_wb_shape = ((self.PE_LORA_DIM + self.Q_NOPE_DIM + 512) * self.NUM_HEADS, self.Q_DIM)
+ wkv_b1_shape = (self.NUM_HEADS, self.KV_CACHE_DIM, self.V_HEAD_DIM)
+ wkv_b2_shape = (self.NUM_HEADS, self.V_HEAD_DIM, self.KV_CACHE_DIM)
+ wkv_b2_scales_shape = (self.NUM_HEADS, self.V_HEAD_DIM // 128, self.KV_CACHE_DIM // 128)
+ unproj_w_shape = (self.HIDDEN_SIZE, self.NUM_HEADS * self.V_HEAD_DIM)
+ unproj_scales_shape = (896, self.NUM_HEADS * self.V_HEAD_DIM // 128)
+
+ x_rmsnorm_gamma = torch.randn(x_rmsnorm_gamma_shape, **self.fp32_desc)
+ qkv_wa_weights, _ = self.register_weights_and_scales(qkv_dim, self.HIDDEN_SIZE)
+ qkv_wa_scales = torch.randn((130, 64), **self.bf16_desc)
+ k_weights = torch.randn(128, **self.fp32_desc)
+ k_bias = torch.randn(128, **self.fp32_desc)
+ q_rmsnorm_gamma = torch.randn(self.Q_DIM, **self.fp32_desc)
+ q_wb_weights, _ = self.register_weights_and_scales(*q_wb_shape)
+ q_wb_scales = torch.randn((448, 12), **self.bf16_desc)
+ id_score_weights = torch.randn(64, self.HIDDEN_SIZE, **self.bf16_desc)
+ wkv_b1_weights = torch.randn(wkv_b1_shape, **self.fp16_desc).to(torch.float8_e4m3fn)
+ wkv_b1_scales = torch.randn((16, 8, 1), **self.bf16_desc)
+ kv_rmsnorm_gamma = torch.randn(self.KV_CACHE_DIM, **self.fp32_desc)
+ wkv_b2_weights = torch.randn(wkv_b2_shape, **self.fp16_desc).to(torch.float8_e4m3fn)
+ wkv_b2_scales = torch.randn(wkv_b2_scales_shape, **self.bf16_desc)
+ unproj_weights = torch.randn(unproj_w_shape, **self.fp16_desc).to(torch.float8_e4m3fn)
+ unproj_scales = torch.randn(unproj_scales_shape, **self.bf16_desc)
+
+ if self.with_weight_conversion:
+ # Apply weight conversion for MLA qkv_wa weights
+ # Convert tilert format -> common format -> tilert native bf16 warp gemv format
+ common_weights = RMSNormProjQAKVAKIWeightsConverter.tilert_to_common(
+ qkv_wa_weights, qkv_wa_scales, x_rmsnorm_gamma
+ )
+ qkv_wa_weights, x_rmsnorm_gamma = (
+ RMSNormProjQAKVAKIWeightsConverter.common_to_tilert_native_bf16_warp_gemv(
+ *common_weights
+ )
+ )
+
+ return MlaParams(
+ x_rmsnorm_gamma=x_rmsnorm_gamma,
+ qkv_wa_weights=qkv_wa_weights,
+ qkv_wa_scales=qkv_wa_scales,
+ k_weights=k_weights,
+ k_bias=k_bias,
+ q_rmsnorm_gamma=q_rmsnorm_gamma,
+ q_wb_weights=q_wb_weights,
+ q_wb_scales=q_wb_scales,
+ id_score_weights=id_score_weights,
+ wkv_b1_weights=wkv_b1_weights,
+ wkv_b1_scales=wkv_b1_scales,
+ kv_rmsnorm_gamma=kv_rmsnorm_gamma,
+ wkv_b2_weights=wkv_b2_weights,
+ wkv_b2_scales=wkv_b2_scales,
+ unproj_weights=unproj_weights,
+ unproj_scales=unproj_scales,
+ )
+
+ def init_mlp_params(self) -> MLPParams:
+ exp_upgate_w_shape = (9, self.EXP_DIMS * 2, self.HIDDEN_SIZE)
+ exp_upgate_s_shape = (9, self.EXP_DIMS * 2 // 128, 64)
+ exp_down_w_shape = (9, self.HIDDEN_SIZE, self.EXP_DIMS)
+ exp_down_s_shape = (9, 1024, self.EXP_DIMS // 128)
+
+ unproj_o_gamma = torch.randn(self.HIDDEN_SIZE, **self.fp32_desc)
+ upgate_weights = torch.randn(exp_upgate_w_shape, **self.fp16_desc).to(torch.float8_e4m3fn)
+ upgate_scales = torch.randn(exp_upgate_s_shape, **self.bf16_desc)
+ down_weights = torch.randn(exp_down_w_shape, **self.fp16_desc).to(torch.float8_e4m3fn)
+ down_scales = torch.randn(exp_down_s_shape, **self.bf16_desc)
+
+ return MLPParams(
+ unproj_o_gamma,
+ upgate_weights,
+ upgate_scales,
+ down_weights,
+ down_scales,
+ )
+
+ def init_moe_params(self) -> MoEParams:
+ from tilert.models.preprocess.weight_utils import (
+ ExpertSelectUpGateSiLUWeightsConverter,
+ )
+
+ exp_ug_w_shape = (self.N_ROUTED_EXPERTS + 1, self.EXP_DIMS * 2, self.HIDDEN_SIZE)
+ exp_upgate_s_shape = (self.N_ROUTED_EXPERTS + 1, self.EXP_DIMS * 2 // 128, 64)
+ exp_down_w_shape = (self.N_ROUTED_EXPERTS + 1, self.HIDDEN_SIZE, self.EXP_DIMS)
+ exp_down_s_shape = (self.N_ROUTED_EXPERTS + 1, 1024, self.EXP_DIMS // 128)
+
+ unproj_o_gamma = torch.randn(self.HIDDEN_SIZE, **self.fp32_desc)
+ exp_proj_weights = torch.randn((self.N_ROUTED_EXPERTS, self.HIDDEN_SIZE), **self.bf16_desc)
+ exp_bias = torch.randn(self.N_ROUTED_EXPERTS, **self.fp32_desc)
+ exp_upgate_weights = torch.randn(exp_ug_w_shape, **self.fp16_desc).to(torch.float8_e4m3fn)
+ exp_upgate_scales = torch.randn(exp_upgate_s_shape, **self.bf16_desc)
+ exp_down_weights = torch.randn(exp_down_w_shape, **self.fp16_desc).to(torch.float8_e4m3fn)
+ exp_down_scales = torch.randn(exp_down_s_shape, **self.bf16_desc)
+
+ if self.with_weight_conversion:
+ # Apply weight conversion for MOE exp_upgate weights
+ exp_upgate_weights = ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm_mma(
+ exp_upgate_weights, exp_upgate_scales
+ )
+
+ return MoEParams(
+ unproj_o_gamma,
+ exp_proj_weights,
+ exp_bias,
+ exp_upgate_weights,
+ exp_upgate_scales,
+ exp_down_weights,
+ exp_down_scales,
+ )
+
+ def init_mtp_preprocess_params(self) -> MTPPreprocessParams:
+ """Initialize MTP preprocess parameters with random values."""
+ embedding_rmsnorm_gamma = torch.randn(self.HIDDEN_SIZE, **self.fp32_desc)
+ hidden_rmsnorm_gamma = torch.randn(self.HIDDEN_SIZE, **self.fp32_desc)
+ eh_proj_weights = torch.randn((128, 7, 56, 256), **self.bf16_desc)
+ return MTPPreprocessParams(
+ embedding_rmsnorm_gamma,
+ hidden_rmsnorm_gamma,
+ eh_proj_weights,
+ )
+
+ def acquire_params(self) -> list[torch.Tensor]:
+ params = []
+
+ for _ in range(self.NUM_DENSE_LAYERS):
+ params.extend(self.init_mla_params().get_params())
+ params.extend(self.init_mlp_params().get_params())
+
+ for _ in range(self.NUM_MOE_LAYERS):
+ params.extend(self.init_mla_params().get_params())
+ params.extend(self.init_moe_params().get_params())
+
+ params.extend(self.init_llm_head_params().get_params())
+ params.extend(self.init_embedding_params().get_params())
+
+ if self.with_mtp:
+ self.mtp_params_sidx = len(params)
+ params.extend(self.init_embedding_params().get_params())
+ params.extend(self.init_mtp_preprocess_params().get_params())
+ params.extend(self.init_mla_params().get_params())
+ params.extend(self.init_moe_params().get_params())
+ params.extend(self.init_llm_head_params().get_params())
+
+ return params
+
+ def acquire_temp_vars(self, seq_len: int | None = None) -> TempVars:
+ """Acquire temporary variables for the model.
+
+ Args:
+ seq_len: Sequence length for temp vars. If None, uses self.max_seq_len.
+
+ Returns:
+ TempVars object containing all temporary tensors.
+ """
+ seq_len = seq_len if seq_len is not None else self.max_seq_len
+ BATCH_SEQ = (self.BATCH_SIZE, seq_len)
+
+ q = torch.zeros(*BATCH_SEQ, self.Q_DIM, **self.bf16_desc)
+ kv = torch.zeros(*BATCH_SEQ, self.KV_CACHE_DIM, **self.bf16_desc)
+ q_pe = torch.zeros(*BATCH_SEQ, self.NUM_HEADS, self.PE_LORA_DIM, **self.bf16_desc)
+ ki = torch.zeros(*BATCH_SEQ, self.KI_CACHE_DIM, **self.bf16_desc)
+ q_nope_down = torch.zeros(*BATCH_SEQ, self.NUM_HEADS, self.V_HEAD_DIM, **self.bf16_desc)
+ q_nope = torch.zeros(*BATCH_SEQ, self.NUM_HEADS, self.Q_PE_DIM, **self.bf16_desc)
+ iq = torch.zeros(*BATCH_SEQ, self.NUM_KI_HEADS, self.KI_CACHE_DIM, **self.bf16_desc)
+ iq_rt = torch.zeros(*BATCH_SEQ, self.NUM_KI_HEADS, self.KI_CACHE_DIM, **self.bf16_desc)
+ idx_score = torch.zeros(*BATCH_SEQ, self.NUM_KI_HEADS, **self.bf16_desc)
+ idx_logits = torch.zeros(*BATCH_SEQ, self.max_ctx_len, **self.fp32_desc)
+ idx_sels = torch.zeros(*BATCH_SEQ, self.MAX_SEL_TOKENS, **self.int32_desc)
+ o = torch.zeros(*BATCH_SEQ, self.NUM_HEADS, self.KV_CACHE_DIM, **self.bf16_desc)
+ o_acc = torch.zeros(*BATCH_SEQ, self.NUM_HEADS, 32, self.KV_CACHE_DIM, **self.fp32_desc)
+ o_lse = torch.empty(*BATCH_SEQ, self.NUM_HEADS, **self.fp32_desc)
+ o_lse_acc = torch.empty(*BATCH_SEQ, self.NUM_HEADS, 32, **self.fp32_desc)
+ proj_o = torch.zeros(*BATCH_SEQ, self.NUM_HEADS, self.V_HEAD_DIM, **self.bf16_desc)
+ unproj_o = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc)
+ scores = torch.zeros(*BATCH_SEQ, self.N_ROUTED_EXPERTS, **self.fp32_desc)
+ x_mlp_in = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc)
+ exp_up_gate = torch.zeros(*BATCH_SEQ, self.N_TOTAL_EXPERTS, self.EXP_DIMS, **self.bf16_desc)
+ sel_probs = torch.zeros(*BATCH_SEQ, self.N_ACTIVATE_EXPERTS, **self.fp32_desc)
+ sel_indices = torch.zeros(*BATCH_SEQ, self.N_ACTIVATE_EXPERTS, **self.int32_desc)
+ exp_out = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc)
+ x_rmsnorm = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc)
+ logits_out = torch.zeros(*BATCH_SEQ, self.VOCAB_SIZE, **self.fp32_desc)
+ token_out = torch.zeros(*BATCH_SEQ, 1, **self.int32_desc)
+
+ embedding_rmsnorm = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc)
+ hidden_rmsnorm = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc)
+ eh_proj = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc)
+ x_tensor = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc)
+ rope_freqs = torch.zeros(*BATCH_SEQ, self.PE_CACHE_DIM, **self.fp32_desc)
+ cur_pos = torch.zeros(self.BATCH_SIZE, **self.int32_desc)
+ token_id = torch.zeros(*BATCH_SEQ, 1, **self.int32_desc)
+ last_hidden_states = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc)
+
+ draft_tokens = torch.zeros(*BATCH_SEQ, **self.int32_desc)
+ predicted_tokens = torch.zeros(*BATCH_SEQ, 1, **self.int32_desc)
+ predicted_hidden = torch.zeros(*BATCH_SEQ, self.HIDDEN_SIZE, **self.bf16_desc)
+ accepted_tokens = torch.zeros(self.BATCH_SIZE, **self.int32_desc)
+ next_draft_tokens = torch.zeros(*BATCH_SEQ, **self.int32_desc)
+
+ return TempVars(
+ q,
+ kv,
+ ki,
+ q_nope_down,
+ q_pe,
+ iq,
+ iq_rt,
+ idx_score,
+ idx_logits,
+ idx_sels,
+ q_nope,
+ o,
+ o_acc,
+ o_lse,
+ o_lse_acc,
+ proj_o,
+ unproj_o,
+ scores,
+ x_mlp_in,
+ exp_up_gate,
+ sel_probs,
+ sel_indices,
+ exp_out,
+ x_rmsnorm,
+ logits_out,
+ token_out,
+ embedding_rmsnorm,
+ hidden_rmsnorm,
+ eh_proj,
+ x_tensor,
+ rope_freqs,
+ cur_pos,
+ token_id,
+ last_hidden_states,
+ draft_tokens,
+ predicted_tokens,
+ predicted_hidden,
+ accepted_tokens,
+ next_draft_tokens,
+ )
+
+ def acquire_cache_vars(self, num_layers: int | None = None) -> list[torch.Tensor]:
+ """Acquire cache variables for the model.
+
+ Args:
+ num_layers: Number of layers to create cache for. If None, uses NUM_LAYERS.
+
+ Returns:
+ List of cache tensors (3 tensors per layer: k_cache, kv_cache, pe_cache).
+ """
+ num_layers = num_layers if num_layers is not None else self.NUM_LAYERS
+ if self.with_mtp:
+ num_layers += 1
+
+ BATCH_CTX = (self.BATCH_SIZE, self.max_ctx_len)
+ cache_vars = []
+ for _ in range(num_layers):
+ cache_vars.extend(
+ [
+ torch.zeros(*BATCH_CTX, self.KI_CACHE_DIM, **self.bf16_desc),
+ torch.zeros(*BATCH_CTX, self.KV_CACHE_DIM, **self.bf16_desc),
+ torch.zeros(*BATCH_CTX, self.PE_CACHE_DIM, **self.bf16_desc),
+ ]
+ )
+ return cache_vars
+
+ def acquire_single_layer_cache_vars(self) -> list[torch.Tensor]:
+ """Acquire cache variables for a single layer.
+
+ Returns:
+ List of 3 cache tensors: k_cache, kv_cache, pe_cache.
+ """
+ return self.acquire_cache_vars(num_layers=1)
+
+ def acquire_misc_vars(self) -> list[torch.Tensor]:
+ return [
+ torch.zeros(self.MAX_OPS, 148, 16, **self.uint64_desc),
+ torch.zeros(self.MAX_OPS, 128, **self.uint8_desc),
+ torch.zeros(self.MAX_OPS, 8, **self.uint8_desc),
+ ]
+
+ def get_mtp_all_vars(self) -> list[torch.Tensor]:
+ return [
+ self.acquire_params()[self.mtp_params_sidx :],
+ self.acquire_temp_vars().get_params(),
+ self.acquire_cache_vars()[-3:],
+ # Potential issue: Reallocate misc vars for MTP
+ self.acquire_misc_vars(),
+ ]
diff --git a/python/models/preprocess/weight_utils.py b/python/models/preprocess/weight_utils.py
index 9e936a5..5330e14 100644
--- a/python/models/preprocess/weight_utils.py
+++ b/python/models/preprocess/weight_utils.py
@@ -11,9 +11,10 @@
__all__ = [
"print_weights_info",
"WeightLoader",
+ "DownAllreduceWeightsConverter",
+ "RMSNormProjQAKVAKIWeightsConverter",
"RMSNormHeadProjWeightsConverter",
"ExpertSelectUpGateSiLUWeightsConverter",
- "RMSNormProjQAKVAKIRopeWeightsConverter",
"RMSNormUpGateSiLUWeightsConverter",
]
@@ -233,6 +234,101 @@ def get_weight(self, name: str, from_tilert: bool = False) -> Any:
return weight_dict[name]["data"]
+class RMSNormProjQAKVAKIWeightsConverter:
+ """Weights converter class."""
+
+ @staticmethod
+ def tilert_to_common(
+ tilert_wqkv_a: torch.Tensor,
+ tilert_wqkv_a_scales: torch.Tensor,
+ tilert_attn_norm_weight: torch.Tensor,
+ ) -> tuple[
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ ]:
+ """
+ Convert tilert weights to common weights.
+
+ Args:
+ tilert_wqkv_a: Tilert weight tensor.
+ tilert_wqkv_a_scales: Tilert weight scale tensor.
+ tilert_attn_norm_weight: Tilert attention norm weight tensor.
+ Returns:
+ tuple: Common weights.
+ """
+ wq_a = tilert_wqkv_a[:1536] # 1536, 7168
+ wkv_a = tilert_wqkv_a[1536 : 1536 + 576] # 576, 7168
+ wk = tilert_wqkv_a[1536 + 576 :] # 128, 7168
+
+ wqkv_a_scales_0 = tilert_wqkv_a_scales[:128, :].reshape(16, 8, 64)
+ wqkv_a_scales_0 = wqkv_a_scales_0[:, 0, :].reshape(16, 64)
+ wqkv_a_scales_1 = tilert_wqkv_a_scales[128:129, :] # 1, 64
+ wqkv_a_scales_2 = tilert_wqkv_a_scales[129:, :] # 1, 64
+ wqkv_a_scales_swizzled = torch.cat(
+ [wqkv_a_scales_0, wqkv_a_scales_1, wqkv_a_scales_2], dim=0
+ )
+ wqkv_scales = torch.zeros(
+ (18, 56), dtype=torch.bfloat16, device=tilert_wqkv_a_scales.device
+ )
+
+ for i in range(64):
+ if ((i % 8) * 8 + i // 8) < 56:
+ wqkv_scales[:, ((i % 8) * 8 + i // 8)] = wqkv_a_scales_swizzled[:, i]
+ wq_a_scale = wqkv_scales[:12, :] # 12, 56
+ wkv_a_scale = wqkv_scales[12:17, :] # 5, 56
+ wk_scale = wqkv_scales[17:, :] # 1, 56
+
+ attn_norm_weight = tilert_attn_norm_weight
+ return wq_a, wq_a_scale, wkv_a, wkv_a_scale, wk, wk_scale, attn_norm_weight
+
+ @staticmethod
+ def common_to_tilert_native_bf16_warp_gemv(
+ wq_a: torch.Tensor,
+ wq_a_scale: torch.Tensor,
+ wkv_a: torch.Tensor,
+ wkv_a_scale: torch.Tensor,
+ wk: torch.Tensor,
+ wk_scale: torch.Tensor,
+ attn_norm_weight: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Convert common weights to weights for tilert native bf16 warp gemv op.
+
+ Args:
+ wq_a: Common weight tensor.
+ wq_a_scale: Common weight scale tensor.
+ wkv_a: Common weight tensor.
+ wkv_a_scale: Common weight scale tensor.
+ wk: Common weight tensor.
+ wk_scale: Common weight scale tensor.
+ attn_norm_weight: Common attention norm weight tensor.
+ Returns:
+ tuple: Tilert weights for native bf16 warp gemv op.
+ """
+ wq_a_scale = wq_a_scale.reshape((12, 56, 1)).repeat(1, 1, 128).reshape((12, 1, 7168))
+ wq_a_scale = wq_a_scale.repeat(1, 128, 1).reshape((1536, 7168))
+ wkv_a_scale = wkv_a_scale.reshape((5, 56, 1)).repeat(1, 1, 128).reshape((5, 1, 7168))
+ wkv_a_scale = wkv_a_scale.repeat(1, 128, 1).reshape((-1, 7168))
+ wkv_a_scale = wkv_a_scale[:576]
+ wk_scale = wk_scale.reshape((1, 56, 1)).repeat(1, 1, 128).reshape((1, 1, 7168))
+ wk_scale = wk_scale.repeat(1, 128, 1).reshape((128, 7168))
+ wq_a = wq_a.reshape((1536, 7168)).float() * wq_a_scale.float()
+ wkv_a = wkv_a.reshape((576, 7168)).float() * wkv_a_scale.float()
+ wk = wk.reshape((128, 7168)).float() * wk_scale.float()
+ # concatenate the weights
+ weights = torch.cat([wq_a, wkv_a, wk], dim=0)
+ assert weights.shape == (1536 + 576 + 128, 7168)
+
+ weights = weights.reshape(140, 16, 7, 1024)
+ weights = weights.transpose(1, 2) # 140, 7, 16, 1024
+ return weights.to(torch.bfloat16).contiguous(), attn_norm_weight.clone()
+
+
class ExpertSelectUpGateSiLUWeightsConverter:
"""Weights converter class."""
@@ -244,9 +340,17 @@ def _swizzle_mma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4)
return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2)
+ @staticmethod
+ def _swizzle_mma_16x16(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 16
+ # PTX isa fig.88
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 2).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 2).transpose(-3, -2)
+
@staticmethod
def tilert_to_tilert_144sm(
- mat_in: torch.Tensor, mat_scale_in: torch.Tensor, swizzle_for_mma_16x32: bool = False
+ mat_in: torch.Tensor, mat_scale_in: torch.Tensor, mma_type: str | None = None
) -> torch.Tensor:
"""
Convert tilert weights and scales to tilert_144sm input format.
@@ -254,6 +358,7 @@ def tilert_to_tilert_144sm(
Args:
mat_in: tilert weights
mat_scale_in: tilert scales
+ mma_type: MMA type, None,"16x32" or "16x16"
Returns:
tilert_144sm weights and scales
"""
@@ -266,13 +371,20 @@ def tilert_to_tilert_144sm(
# to 16x1024 blocks
weights_w1 = weights_w1.reshape(exp_num, 16, 16, 7, 1024).transpose(2, 3)
weights_w3 = weights_w3.reshape(exp_num, 16, 16, 7, 1024).transpose(2, 3)
- if swizzle_for_mma_16x32:
+ if mma_type == "16x32":
weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 32, 32).transpose(3, 4)
weights_w1 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x32(weights_w1)
weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 1024)
weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 32, 32).transpose(3, 4)
weights_w3 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x32(weights_w3)
weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 1024)
+ elif mma_type == "16x16":
+ weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 64, 16).transpose(3, 4)
+ weights_w1 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x16(weights_w1)
+ weights_w1 = weights_w1.reshape(exp_num, 16, 7, 16, 1024)
+ weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 64, 16).transpose(3, 4)
+ weights_w3 = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x16(weights_w3)
+ weights_w3 = weights_w3.reshape(exp_num, 16, 7, 16, 1024)
weights = torch.cat([weights_w1, weights_w3], dim=3)
assert weights.shape == (exp_num, 16, 7, 32, 1024)
@@ -303,8 +415,7 @@ def tilert_to_tilert_144sm(
@staticmethod
def tilert_to_tilert_144sm_mma(
- mat_in: torch.Tensor,
- mat_scale_in: torch.Tensor,
+ mat_in: torch.Tensor, mat_scale_in: torch.Tensor, mma_type: str = "16x32"
) -> torch.Tensor:
"""
Convert tilert weights and scales to tilert_144sm_mma input format.
@@ -316,7 +427,7 @@ def tilert_to_tilert_144sm_mma(
tilert_144sm weights and scales
"""
return ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm(
- mat_in, mat_scale_in, True
+ mat_in, mat_scale_in, mma_type
)
@@ -333,296 +444,125 @@ def tilert_to_tilert_native_bf16_warp_gemv(
return weights.contiguous()
-class RMSNormProjQAKVAKIRopeWeightsConverter:
+class RMSNormUpGateSiLUWeightsConverter:
"""Weights converter class."""
@staticmethod
- def tilert_to_common(
- tilert_wqkv_a: torch.Tensor,
- tilert_wqkv_a_scales: torch.Tensor,
- tilert_attn_norm_weight: torch.Tensor,
- ) -> tuple[
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- ]:
+ def tilert_to_tilert_144sm(
+ mat_in: torch.Tensor,
+ mat_scale_in: torch.Tensor,
+ mma_type: str | None = None,
+ ) -> torch.Tensor:
"""
- Convert tilert weights to common weights.
+ Convert tilert weights and scales to tilert_144sm input format.
Args:
- tilert_wqkv_a: Tilert weight tensor.
- tilert_wqkv_a_scales: Tilert weight scale tensor.
- tilert_attn_norm_weight: Tilert attention norm weight tensor.
+ mat_in: tilert weights
+ mat_scale_in: tilert scales
Returns:
- tuple: Common weights.
+ tilert_144sm weights and scales
"""
- wq_a = tilert_wqkv_a[:1536] # 1536, 7168
- wkv_a = tilert_wqkv_a[1536 : 1536 + 576] # 576, 7168
- wk = tilert_wqkv_a[1536 + 576 :] # 128, 7168
-
- wqkv_a_scales_0 = tilert_wqkv_a_scales[:128, :].reshape(16, 8, 64)
- wqkv_a_scales_0 = wqkv_a_scales_0[:, 0, :].reshape(16, 64)
- wqkv_a_scales_1 = tilert_wqkv_a_scales[128:129, :] # 1, 64
- wqkv_a_scales_2 = tilert_wqkv_a_scales[129:, :] # 1, 64
- wqkv_a_scales_swizzled = torch.cat(
- [wqkv_a_scales_0, wqkv_a_scales_1, wqkv_a_scales_2], dim=0
+ return ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm(
+ mat_in, mat_scale_in, mma_type
)
- wqkv_scales = torch.zeros(
- (18, 56), dtype=torch.bfloat16, device=tilert_wqkv_a_scales.device
+
+ @staticmethod
+ def tilert_to_tilert_144sm_mma(
+ mat_in: torch.Tensor,
+ mat_scale_in: torch.Tensor,
+ mma_type: str = "16x32",
+ ) -> torch.Tensor:
+ """Convert tilert weights and scales to tilert_144sm_mma input format."""
+ return ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm_mma(
+ mat_in, mat_scale_in, mma_type
)
- for i in range(64):
- if ((i % 8) * 8 + i // 8) < 56:
- wqkv_scales[:, ((i % 8) * 8 + i // 8)] = wqkv_a_scales_swizzled[:, i]
- wq_a_scale = wqkv_scales[:12, :] # 12, 56
- wkv_a_scale = wqkv_scales[12:17, :] # 5, 56
- wk_scale = wqkv_scales[17:, :] # 1, 56
- attn_norm_weight = tilert_attn_norm_weight
- return wq_a, wq_a_scale, wkv_a, wkv_a_scale, wk, wk_scale, attn_norm_weight
+class UnProjOAllreduceWeightsConverter:
+ """Weights converter class."""
@staticmethod
- def common_to_tilert(
- wq_a: torch.Tensor,
- wq_a_scale: torch.Tensor,
- wkv_a: torch.Tensor,
- wkv_a_scale: torch.Tensor,
- wk: torch.Tensor,
- wk_scale: torch.Tensor,
- attn_norm_weight: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ def tilert_to_tilert_112sm_mma(
+ mat_in: torch.Tensor,
+ mat_scale_in: torch.Tensor,
+ ) -> torch.Tensor:
"""
- Convert common weights to tilert weights.
+ Convert tilert weights to tilert_112sm_mma input format.
Args:
- wq_a: Common weight tensor.
- wq_a_scale: Common weight scale tensor.
- wkv_a: Common weight tensor.
- wkv_a_scale: Common weight scale tensor.
- wk: Common weight tensor.
- wk_scale: Common weight scale tensor.
- attn_norm_weight: Common attention norm weight tensor.
+ mat_in: tilert weights, [7168, 2048]
+ mat_scale_in: tilert scales, [896, 16]
Returns:
- tuple: Tilert weights.
+ tilert_112sm weights, scales
"""
- wqkv_a = torch.cat([wq_a, wkv_a, wk], dim=0)
- wqkv_a_scales_raw = torch.cat([wq_a_scale, wkv_a_scale, wk_scale], dim=0)
+ swizzle_for_mma_16x32 = True
+ assert mat_in.shape == (7168, 2048)
+ assert mat_scale_in.shape == (896, 16)
+ weights_trt = mat_in.reshape(112, 64, 2048)
+ # to 64*512 blocks
+ weights_trt = weights_trt.reshape(112, 64, 4, 512).transpose(1, 2) # (112, 4, 64, 512)
+ if swizzle_for_mma_16x32:
+ # to (112, 4, 4(n), 16(k), 16, 32)
+ weights_trt = weights_trt.reshape(112, 4, 4, 16, 16, 32).transpose(-2, -3)
+ weights_trt = ExpertSelectUpGateSiLUWeightsConverter._swizzle_mma_16x32(weights_trt)
+ weights_trt = weights_trt.reshape(112, 4, 4 * 16 * 16 * 32)
- wqkv_a_scales = torch.zeros((18, 64), dtype=torch.bfloat16, device=wq_a_scale.device)
- for i in range(64):
- wqkv_a_scales[:, i] = wqkv_a_scales_raw[:, ((i % 8) * 8 + i // 8) % 56]
- if ((i % 8) * 8 + i // 8) >= 56:
- wqkv_a_scales[:, i] = 0.0
- wqkv_a_scales_0 = wqkv_a_scales[:16, :]
- wqkv_a_scales_1 = wqkv_a_scales[16:17, :]
- wqkv_a_scales_2 = wqkv_a_scales[17:, :]
-
- wqkv_a_scales_0 = wqkv_a_scales_0.reshape((16, 1, 64)).repeat(1, 8, 1).reshape(-1, 64)
- wqkv_a_scales = torch.cat([wqkv_a_scales_0, wqkv_a_scales_1, wqkv_a_scales_2], dim=0)
- assert wqkv_a_scales.shape == (130, 64)
- return wqkv_a.contiguous(), wqkv_a_scales.contiguous(), attn_norm_weight.clone()
+ # For scales
+ scales_trt = mat_scale_in.reshape(56, 16, 16)[:, 0, :]
+ scales_trt = scales_trt.reshape(56, 16)
+ return weights_trt.contiguous(), scales_trt.contiguous()
- @staticmethod
- def common_to_tilert_fp8(
- wq_a: torch.Tensor,
- wq_a_scale: torch.Tensor,
- wkv_a: torch.Tensor,
- wkv_a_scale: torch.Tensor,
- wk: torch.Tensor,
- wk_scale: torch.Tensor,
- attn_norm_weight: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Convert common weights to tilert weights.
- Args:
- wq_a: Common weight tensor.
- wq_a_scale: Common weight scale tensor.
- wkv_a: Common weight tensor.
- wkv_a_scale: Common weight scale tensor.
- wk: Common weight tensor.
- wk_scale: Common weight scale tensor.
- attn_norm_weight: Common attention norm weight tensor.
- Returns:
- tuple: Tilert fp8 weights.
- """
- wq_a_raw: torch.Tensor = wq_a.detach().clone()
- wkv_a_raw: torch.Tensor = wkv_a.detach().clone()
- wq_a_raw = torch.cat([wq_a_raw, wkv_a_raw[:512], wk, wkv_a_raw[512:]], dim=0)
-
- wq_a_raw = wq_a_raw.reshape(35, 64, 14, 512)
- wq_a_raw = wq_a_raw.permute(0, 2, 1, 3)
-
- wq_a_raw = wq_a_raw.reshape(35, 14, 16, 4, 4, 128)
- wq_a_copy = wq_a_raw.contiguous().clone()
- wq_a_raw[:, :, 1::2, :, :, :64] = wq_a_copy[:, :, 1::2, :, :, 64:]
- wq_a_raw[:, :, 1::2, :, :, 64:] = wq_a_copy[:, :, 1::2, :, :, :64]
- wq_a_raw = wq_a_raw.reshape(35, 14, 16, 4, 4, 2, 64)
- wq_a_copy = wq_a_raw.contiguous().clone()
- wq_a_raw[:, :, :, 2:, :, :, :32] = wq_a_copy[:, :, :, 2:, :, :, 32:]
- wq_a_raw[:, :, :, 2:, :, :, 32:] = wq_a_copy[:, :, :, 2:, :, :, :32]
- wq_a_raw = wq_a_raw.reshape(35, 14, 16, 4, 4, 2, 2, 32)
- wq_a_copy = wq_a_raw.contiguous().clone()
- wq_a_raw[:, :, :, 1::2, :, :, :, :16] = wq_a_copy[:, :, :, 1::2, :, :, :, 16:]
- wq_a_raw[:, :, :, 1::2, :, :, :, 16:] = wq_a_copy[:, :, :, 1::2, :, :, :, :16]
-
- wq_a_raw = wq_a_raw.reshape(35, 14, 16, 4, 4, 128)
- wq_a_raw = wq_a_raw.permute(0, 1, 4, 2, 3, 5).reshape(35, 14, -1).contiguous()
- wq_a_raw = wq_a_raw.reshape(35, 14, -1).contiguous()
-
- wq_s_raw: torch.Tensor = wq_a_scale.detach().clone()
- wkv_s_raw: torch.Tensor = wkv_a_scale.detach().clone()
- wq_s_raw = torch.cat([wq_s_raw, wkv_s_raw[:4], wk_scale, wkv_s_raw[4:]], dim=0)
- wq_s_raw = wq_s_raw.reshape(18, 1, 14, 4).repeat(1, 2, 1, 1).reshape(36, 1, 14, 4)
- wq_s_raw = wq_s_raw[:35].reshape(35, 14, -1).contiguous()
- wq_s_raw = wq_s_raw.view(torch.float8_e4m3fn)
- wq_as_raw = torch.cat([wq_a_raw, wq_s_raw], dim=-1)
-
- return wq_as_raw.contiguous(), attn_norm_weight.clone()
+class DownAllreduceWeightsConverter:
+ """Weights converter class."""
@staticmethod
- def common_to_tilert_native_bf16(
- wq_a: torch.Tensor,
- wq_a_scale: torch.Tensor,
- wkv_a: torch.Tensor,
- wkv_a_scale: torch.Tensor,
- wk: torch.Tensor,
- wk_scale: torch.Tensor,
- attn_norm_weight: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Convert common weights to weights for tilert native bf16 op.
-
- Args:
- wq_a: Common weight tensor.
- wq_a_scale: Common weight scale tensor.
- wkv_a: Common weight tensor.
- wkv_a_scale: Common weight scale tensor.
- wk: Common weight tensor.
- wk_scale: Common weight scale tensor.
- attn_norm_weight: Common attention norm weight tensor.
- Returns:
- tuple: Tilert weights for native bf16 op.
- """
- wq_a_scale = wq_a_scale.reshape((12, 56, 1)).repeat(1, 1, 128).reshape((12, 1, 7168))
- wq_a_scale = wq_a_scale.repeat(1, 128, 1).reshape((1536, 7168))
- wkv_a_scale = wkv_a_scale.reshape((5, 56, 1)).repeat(1, 1, 128).reshape((5, 1, 7168))
- wkv_a_scale = wkv_a_scale.repeat(1, 128, 1).reshape((-1, 7168))
- wkv_a_scale = wkv_a_scale[:576]
- wk_scale = wk_scale.reshape((1, 56, 1)).repeat(1, 1, 128).reshape((1, 1, 7168))
- wk_scale = wk_scale.repeat(1, 128, 1).reshape((128, 7168))
- wq_a = wq_a.reshape((1536, 7168)).float() * wq_a_scale.float()
- wkv_a = wkv_a.reshape((576, 7168)).float() * wkv_a_scale.float()
- wk = wk.reshape((128, 7168)).float() * wk_scale.float()
- weights = torch.cat([wq_a, wkv_a, wk], dim=0)
- assert weights.shape == (1536 + 576 + 128, 7168)
- return weights.to(torch.bfloat16).contiguous(), attn_norm_weight.clone()
+ def _swizzle_mma_16x32(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 16 and mat_in.shape[-1] == 32
+ # PTX isa fig.88
+ pre_shape = mat_in.shape[:-2]
+ mat_in = mat_in.reshape(*pre_shape, 2, 8, 2, 4, 4).transpose(-4, -3).transpose(-5, -4)
+ return mat_in.reshape(*pre_shape, 2 * 2, 8 * 4, 4).transpose(-3, -2)
@staticmethod
- def common_to_tilert_native_bf16_warp_gemv(
- wq_a: torch.Tensor,
- wq_a_scale: torch.Tensor,
- wkv_a: torch.Tensor,
- wkv_a_scale: torch.Tensor,
- wk: torch.Tensor,
- wk_scale: torch.Tensor,
- attn_norm_weight: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Convert common weights to weights for tilert native bf16 warp gemv op.
-
- Args:
- wq_a: Common weight tensor.
- wq_a_scale: Common weight scale tensor.
- wkv_a: Common weight tensor.
- wkv_a_scale: Common weight scale tensor.
- wk: Common weight tensor.
- wk_scale: Common weight scale tensor.
- attn_norm_weight: Common attention norm weight tensor.
- Returns:
- tuple: Tilert weights for native bf16 warp gemv op.
- """
- wq_a_scale = wq_a_scale.reshape((12, 56, 1)).repeat(1, 1, 128).reshape((12, 1, 7168))
- wq_a_scale = wq_a_scale.repeat(1, 128, 1).reshape((1536, 7168))
- wkv_a_scale = wkv_a_scale.reshape((5, 56, 1)).repeat(1, 1, 128).reshape((5, 1, 7168))
- wkv_a_scale = wkv_a_scale.repeat(1, 128, 1).reshape((-1, 7168))
- wkv_a_scale = wkv_a_scale[:576]
- wk_scale = wk_scale.reshape((1, 56, 1)).repeat(1, 1, 128).reshape((1, 1, 7168))
- wk_scale = wk_scale.repeat(1, 128, 1).reshape((128, 7168))
- wq_a = wq_a.reshape((1536, 7168)).float() * wq_a_scale.float()
- wkv_a = wkv_a.reshape((576, 7168)).float() * wkv_a_scale.float()
- wk = wk.reshape((128, 7168)).float() * wk_scale.float()
- # concatenate the weights
- weights = torch.cat([wq_a, wkv_a, wk], dim=0)
- assert weights.shape == (1536 + 576 + 128, 7168)
-
- weights = weights.reshape(140, 16, 7, 1024)
- weights = weights.transpose(1, 2) # 140, 7, 16, 1024
- return weights.to(torch.bfloat16).contiguous(), attn_norm_weight.clone()
+ def _swizzle_mma_8x32(mat_in: torch.Tensor) -> torch.Tensor:
+ assert mat_in.shape[-2] == 8 and mat_in.shape[-1] == 32
+ # PTX isa fig.88
+ pre_shape = mat_in.shape[:-2]
+ return mat_in.reshape(*pre_shape, 8, 2, 4, 4).transpose(-2, -3).contiguous()
@staticmethod
- def common_to_tilert_dequant_bf16(
- wq_a: torch.Tensor,
- wq_a_scale: torch.Tensor,
- wkv_a: torch.Tensor,
- wkv_a_scale: torch.Tensor,
- wk: torch.Tensor,
- wk_scale: torch.Tensor,
- attn_norm_weight: torch.Tensor,
+ def tilert_to_tilert_mma(
+ mat_in: torch.Tensor, scale_in: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""
- Convert common weights to weights for tilert dequant bf16 op.
-
- Args:
- wq_a: Common weight tensor.
- wq_a_scale: Common weight scale tensor.
- wkv_a: Common weight tensor.
- wkv_a_scale: Common weight scale tensor.
- wk: Common weight tensor.
- wk_scale: Common weight scale tensor.
- attn_norm_weight: Common attention norm weight tensor.
- Returns:
- tuple: Tilert weights for dequant bf16 op.
- """
- wq_a = wq_a.reshape((384, 4, 7168))
- wkv_a = wkv_a.reshape((144, 4, 7168))
- wk = wk.reshape((32, 4, 7168))
- wqkv = torch.cat([wq_a, wkv_a, wk], dim=0).reshape(140, 4, 4 * 7168)
-
- wq_a_scale = wq_a_scale.reshape((12, 1, 56)).repeat(1, 32, 1).reshape((384, 1, 56))
- wkv_a_scale = wkv_a_scale.reshape((5, 1, 56)).repeat(1, 32, 1).reshape((160, 1, 56))[:144]
- wk_scale = wk_scale.reshape((1, 1, 56)).repeat(1, 32, 1).reshape((32, 1, 56))
- wqkv_scales = torch.cat([wq_a_scale, wkv_a_scale, wk_scale], dim=0).reshape(140, 4, 56)
- wqkv_scales_swizzled = torch.zeros(140, 4, 64, dtype=torch.bfloat16, device=wq_a.device)
- # swizzle
- for i in range(64):
- wqkv_scales_swizzled[..., i] = wqkv_scales[..., ((i % 8) * 8 + i // 8) % 56]
- weights = torch.zeros(
- 140, 4, 4 * 7168 + 64 * 2, dtype=torch.float8_e4m3fn, device=wq_a.device
- )
- weights_part = weights[:, :, : 4 * 7168]
- scales_part = weights[:, :, 4 * 7168 :]
- weights_part.copy_(wqkv)
- scales_part.copy_(wqkv_scales_swizzled.view(dtype=torch.float8_e4m3fn))
- return weights.contiguous(), attn_norm_weight.clone()
-
-
-class RMSNormUpGateSiLUWeightsConverter:
- """Weights converter class."""
-
- @staticmethod
- def tilert_to_tilert_144sm(mat_in: torch.Tensor, mat_scale_in: torch.Tensor) -> torch.Tensor:
- """
- Convert tilert weights and scales to tilert_144sm input format.
+ Convert tilert weights and scales to tilert_mma input format.
Args:
mat_in: tilert weights
mat_scale_in: tilert scales
Returns:
- tilert_144sm weights and scales
+ tilert_mma weights and scales
"""
- return ExpertSelectUpGateSiLUWeightsConverter.tilert_to_tilert_144sm(mat_in, mat_scale_in)
+ exp_num = mat_in.shape[0]
+ mat_in_s = mat_in.reshape(exp_num, 128, 56, 256)
+ # mat_in_s[:, :, 48:56] = 0;
+ mat_in_0 = mat_in_s[:, :, :16].reshape(exp_num, 128, 16, 8, 32).transpose(2, 3)
+ mat_in_0 = DownAllreduceWeightsConverter._swizzle_mma_16x32(mat_in_0).reshape(
+ exp_num, 128, -1
+ )
+ mat_in_1 = mat_in_s[:, :, 16:32].reshape(exp_num, 128, 16, 8, 32).transpose(2, 3)
+ mat_in_1 = DownAllreduceWeightsConverter._swizzle_mma_16x32(mat_in_1).reshape(
+ exp_num, 128, -1
+ )
+ mat_in_2 = mat_in_s[:, :, 32:48].reshape(exp_num, 128, 16, 8, 32).transpose(2, 3)
+ mat_in_2 = DownAllreduceWeightsConverter._swizzle_mma_16x32(mat_in_2).reshape(
+ exp_num, 128, -1
+ )
+ mat_in_3 = mat_in_s[:, :, 48:56].reshape(exp_num, 128, 8, 8, 32).transpose(2, 3)
+ mat_in_3 = DownAllreduceWeightsConverter._swizzle_mma_8x32(mat_in_3).reshape(
+ exp_num, 128, -1
+ )
+
+ mat_in_swizzled = torch.cat([mat_in_0, mat_in_1, mat_in_2, mat_in_3], dim=2)
+ return mat_in_swizzled.reshape(exp_num, 7168, 256).contiguous(), scale_in
diff --git a/python/models/utils.py b/python/models/utils.py
index 123176d..e85242b 100644
--- a/python/models/utils.py
+++ b/python/models/utils.py
@@ -126,6 +126,7 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso
return y_out.to(dtype)
+# enumerate swizzle mode
class SwizzleMode(IntEnum):
"""Swizzle mode."""
@@ -135,6 +136,7 @@ class SwizzleMode(IntEnum):
SWIZZLE_128B = 128 // 16
+# See CUDA C++ programming Guide 10.29.3.2 for more details.
def gen_tensor_swizzle_map_1d(
rows: int, cols_in_16bytes: int, swizzle_mode: SwizzleMode = SwizzleMode.SWIZZLE_128B
) -> torch.Tensor:
diff --git a/python/utils.py b/python/utils.py
index c1662e5..47335d7 100644
--- a/python/utils.py
+++ b/python/utils.py
@@ -48,7 +48,9 @@ def cosine_similarity(gt: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
Returns:
The cosine similarity.
"""
- return torch.nn.functional.cosine_similarity(gt.flatten(), out.flatten(), dim=-1)
+ return torch.nn.functional.cosine_similarity(
+ gt.flatten().float(), out.flatten().float(), dim=-1
+ )
def relative_l2_error(gt: torch.Tensor, out: torch.Tensor) -> Any:
diff --git a/requirements-ci.txt b/requirements-ci.txt
deleted file mode 100644
index c8fb54f..0000000
--- a/requirements-ci.txt
+++ /dev/null
@@ -1,39 +0,0 @@
-# --- Linting ---
-black==25.1.0
-isort==6.0.1
-flake8==7.1.2
-flake8-bugbear==24.12.12
-flake8-comprehensions==3.16.0
-flake8-docstrings==1.7.0
-flake8-simplify==0.21.0
-flake8-unused-arguments==0.0.13
-flake8-variables-names==0.0.6
-flake8-return==1.2.0
-flake8-print==5.0.0
-mypy==1.15.0
-tomli==2.2.1
-bandit==1.8.3
-pyupgrade==3.19.1
-clang-format==18.1.5
-types-setuptools
-types-requests
-types-urllib3
-types-six
-
-# --- Docs formatting (if needed for CI) ---
-mdformat==0.7.17
-mdformat-gfm==0.4.1
-mdformat-frontmatter==2.0.8
-mdformat-myst==0.2.1
-mdformat-tables==1.0.0
-mdformat-toc==0.3.0
-mdformat-black==0.1.1
-
-# --- Pre-commit ---
-pre-commit>=3.0.0
-
-# --- Commitizen ---
-commitizen==4.4.1
-
-# --- Codespell ---
-codespell==2.4.1