From c92f2b55c88d6c86469eb4f053df5a2c8b7aaca5 Mon Sep 17 00:00:00 2001 From: Peng Sun Date: Sat, 7 Feb 2026 21:28:54 -0600 Subject: [PATCH 1/8] Add comprehensive documentation guides and improve README Create 8 documentation guides covering all major ATOM subsystems: - Architecture & Design (request lifecycle, component diagram) - Configuration & CLI Reference (all config classes, env vars) - Model Support (8 architectures, weight loading, adding new models) - Model Operations & AITER Integration (kernel mapping, fused ops) - Scheduling & KV Cache (prefill-first scheduler, block manager, prefix caching) - Compilation & CUDA Graphs (4 levels, 5 modes, piecewise compilation) - Distributed Inference (TP, DP, EP with MORI all-to-all) - Serving & Benchmarks (OpenAI server, profiling, MTP speculative decoding) All guides are fact-checked against the codebase. README updated with expanded features (OpenAI API, quantization, multi-GPU, speculative decoding, prefix caching), supported models table, documentation links, and improved section structure. --- README.md | 129 ++++-- docs/architecture_guide.md | 240 +++++++++++ docs/compilation_cudagraph_guide.md | 515 ++++++++++++++++++++++ docs/configuration_guide.md | 356 +++++++++++++++ docs/distributed_guide.md | 454 +++++++++++++++++++ docs/model_ops_guide.md | 542 +++++++++++++++++++++++ docs/model_support_guide.md | 320 ++++++++++++++ docs/scheduling_kv_cache_guide.md | 595 +++++++++++++++++++++++++ docs/serving_benchmarking_guide.md | 648 ++++++++++++++++++++++++++++ 9 files changed, 3756 insertions(+), 43 deletions(-) create mode 100644 docs/architecture_guide.md create mode 100644 docs/compilation_cudagraph_guide.md create mode 100644 docs/configuration_guide.md create mode 100644 docs/distributed_guide.md create mode 100644 docs/model_ops_guide.md create mode 100644 docs/model_support_guide.md create mode 100644 docs/scheduling_kv_cache_guide.md create mode 100644 docs/serving_benchmarking_guide.md diff --git a/README.md b/README.md index a129f2491..92139a4b2 100644 --- a/README.md +++ b/README.md @@ -4,13 +4,30 @@ -------------------------------------------------------------------------------- -**ATOM** (AiTer Optimized Model) is a lightweight vLLM-like implementation, focusing on integration and optimization based on [aiter](https://github.com/ROCm/aiter). +**ATOM** (AiTer Optimized Model) is a lightweight vLLM-like implementation, focusing on integration and optimization based on [AITER](https://github.com/ROCm/aiter). ## 🚀 Features -- **ROCm Optimized**: Built on AMD's ROCm platform with torch compile support -- **Model Support**: Compatible with **[Deepseek](https://huggingface.co/deepseek-ai)**, **[Qwen](https://huggingface.co/Qwen)**, **[Llama](https://huggingface.co/meta-llama)**, **[Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**, **[GPTOSS](https://huggingface.co/openai)**. -- **Easy Integration**: Simple API for quick deployment +- **ROCm Optimized**: Built on AMD's ROCm platform with [AITER](https://github.com/ROCm/aiter) kernels (ASM, CK, Triton) +- **OpenAI-Compatible API**: Drop-in server with `/v1/chat/completions` and `/v1/completions` endpoints +- **Piecewise torch.compile**: 4 compilation levels with CUDA graph capture for low-latency decode +- **Multi-GPU Parallelism**: Tensor parallelism (TP), data parallelism (DP), and expert parallelism (EP) with MORI all-to-all +- **Quantization**: FP8, MXFP4, INT8, INT4 with auto-detection from HuggingFace configs +- **Speculative Decoding**: Multi-Token Prediction (MTP) with EAGLE proposer +- **Prefix Caching**: xxhash64-based KV cache block sharing across sequences + +### Supported Models + +| Model Family | HF Architecture | Dense/MoE | Notes | +|---|---|---|---| +| [Llama](https://huggingface.co/meta-llama) | `LlamaForCausalLM` | Dense | Llama 2, Llama 3, Llama 3.1 | +| [Qwen3](https://huggingface.co/Qwen) | `Qwen3ForCausalLM` | Dense | | +| [Qwen3-MoE](https://huggingface.co/Qwen) | `Qwen3MoeForCausalLM` | MoE | 128 experts, top-8 routing | +| [DeepSeek V2/V3](https://huggingface.co/deepseek-ai) | `DeepseekV3ForCausalLM` | MoE | MLA attention, MTP speculative decoding | +| [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) | `MixtralForCausalLM` | MoE | 8 experts, top-2 routing | +| [GLM-4-MoE](https://huggingface.co/THUDM) | `Glm4MoeForCausalLM` | MoE | | +| [GPT-OSS](https://huggingface.co/openai) | `GptOssForCausalLM` | Dense | Sliding window + attention sinks | +| [Kimi-K2](https://huggingface.co/moonshotai/Kimi-K2-Thinking) | via `--trust-remote-code` | MoE | See [recipe](recipes/Kimi-K2-Thinking.md) | ## 📋 Requirements @@ -50,48 +67,75 @@ pip install amd-aiter git clone https://github.com/ROCm/ATOM.git; cd ./ATOM; pip install . ``` +## 📚 Documentation + +| **Topic** | **Description** | **Guide** | +|---|---|---| +| Architecture | System overview, request lifecycle, component design | [Architecture Guide](docs/architecture_guide.md) | +| Configuration | Config classes, CLI arguments, environment variables | [Configuration Guide](docs/configuration_guide.md) | +| Model Support | Supported models, weight loading, adding new architectures | [Model Support Guide](docs/model_support_guide.md) | +| Model Operations | AITER kernel integration, linear/attention/MoE/norm wrappers | [Model Ops Guide](docs/model_ops_guide.md) | +| Scheduling & KV Cache | Batch scheduling, block allocation, prefix caching | [Scheduling Guide](docs/scheduling_kv_cache_guide.md) | +| Compilation | torch.compile levels, CUDA graphs, piecewise compilation | [Compilation Guide](docs/compilation_cudagraph_guide.md) | +| Distributed | Tensor/data/expert parallelism, multi-GPU deployment | [Distributed Guide](docs/distributed_guide.md) | +| Serving & Benchmarks | OpenAI API server, benchmarking, profiling, speculative decoding | [Serving Guide](docs/serving_benchmarking_guide.md) | + +**Deployment Recipes:** + +- [Qwen3-235B-A22B](recipes/Qwen3-235b.md) -- TP8 + EP with FP8 KV cache +- [Kimi-K2-Thinking](recipes/Kimi-K2-Thinking.md) -- MXFP4 MoE on 4 GPUs + ## 💡 Usage ### Basic Example -The default optimization level is 3 (running with torch compile). Supported models include **[Deepseek](https://huggingface.co/deepseek-ai)**, **[Qwen](https://huggingface.co/Qwen)**, **[Llama](https://huggingface.co/meta-llama)**, **[Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**, **[GPTOSS](https://huggingface.co/openai)**. +The default optimization level is 3 (piecewise torch.compile with CUDA graphs). ```bash -python -m atom.examples.simple_inference --model meta-llama/Meta-Llama-3-8B --kv_cache_dtype fp8 +python -m atom.examples.simple_inference --model meta-llama/Meta-Llama-3-8B --kv_cache_dtype fp8 ``` > **Note:** First-time execution may take approximately 10 minutes for model compilation. -### Performance profiling +### Serving + +Start an OpenAI-compatible server: -Profile offline inference ```bash -python -m atom.examples.profile_offline --model Qwen/Qwen3-0.6B --kv_cache_dtype fp8 +# Single GPU +python -m atom.entrypoints.openai_server --model Qwen/Qwen3-0.6B --kv_cache_dtype fp8 + +# Multi-GPU with tensor parallelism +python -m atom.entrypoints.openai_server --model deepseek-ai/DeepSeek-R1 --kv_cache_dtype fp8 -tp 8 ``` -Or profile offline with custom input length + +### Profiling + +Profile offline inference: + ```bash -python -m atom.examples.profile_offline --model Qwen/Qwen3-0.6B --kv_cache_dtype fp8 --random-input --input-length 1024 --output-length 32 +python -m atom.examples.profile_offline --model Qwen/Qwen3-0.6B --kv_cache_dtype fp8 ``` -Profile online inference +With custom input/output lengths: + ```bash -curl -s -S -X POST http://127.0.0.1:8000/start_profile +python -m atom.examples.profile_offline --model Qwen/Qwen3-0.6B --kv_cache_dtype fp8 \ + --random-input --input-length 1024 --output-length 32 ``` -Run your task + +Profile a running server: + ```bash +curl -s -S -X POST http://127.0.0.1:8000/start_profile +# ... run your workload ... curl -s -S -X POST http://127.0.0.1:8000/stop_profile ``` -### Performance Benchmarking +### Benchmarking -Run online throughput benchmark: +Run an online throughput benchmark against a running server: -start the server -```bash -python -m atom.entrypoints.openai_server --kv_cache_dtype fp8 --model Qwen/Qwen3-0.6B -python -m atom.entrypoints.openai_server --kv_cache_dtype fp8 -tp 8 --model deepseek-ai/DeepSeek-R1 -``` -run benchmark ```bash MODEL=deepseek-ai/DeepSeek-R1 ISL=1024 @@ -99,53 +143,52 @@ OSL=1024 CONC=128 PORT=8000 RESULT_FILENAME=Deepseek-R1-result - + python -m atom.benchmarks.benchmark_serving \ ---model=$MODEL --backend=vllm --base-url=http://localhost:$PORT \ ---dataset-name=random \ ---random-input-len=$ISL --random-output-len=$OSL \ ---random-range-ratio 0.8 \ ---num-prompts=$(( $CONC * 10 )) \ ---max-concurrency=$CONC \ ---request-rate=inf --ignore-eos \ ---save-result --percentile-metrics="ttft,tpot,itl,e2el" \ ---result-dir=./ --result-filename=$RESULT_FILENAME.json + --model=$MODEL --backend=vllm --base-url=http://localhost:$PORT \ + --dataset-name=random \ + --random-input-len=$ISL --random-output-len=$OSL \ + --random-range-ratio 0.8 \ + --num-prompts=$(( $CONC * 10 )) \ + --max-concurrency=$CONC \ + --request-rate=inf --ignore-eos \ + --save-result --percentile-metrics="ttft,tpot,itl,e2el" \ + --result-dir=./ --result-filename=$RESULT_FILENAME.json ``` - ## 📊 Performance -### Online serving throughput +### Online Serving Throughput ![DS R1 Performance](./docs/ds_r1_performance.png) For more information, visit [InferenceMAX](https://inferencemax.semianalysis.com/). -### Accuracy Benchmarking +### Accuracy Validation -First, install `lm-eval` to test model accuracy: +Install `lm-eval` to test model accuracy: ```bash pip install lm-eval[api] ``` -Next, start an OpenAI-compatible server using `openai_server.py`: +Start a server, then run the evaluation: ```bash -python -m atom.entrypoints.openai_server --model meta-llama/Meta-Llama-3-8B --kv_cache_dtype fp8 +python -m atom.entrypoints.openai_server --model meta-llama/Meta-Llama-3-8B --kv_cache_dtype fp8 ``` -Finally, run the evaluation by choosing your datasets: - ```bash lm_eval --model local-completions \ --model_args model=meta-llama/Meta-Llama-3-8B,base_url=http://localhost:8000/v1/completions,num_concurrent=64,max_retries=3,tokenized_requests=False \ --tasks gsm8k \ --num_fewshot 5 ``` + ## Acknowledgements -This project was adapted from nano-vllm (https://github.com/GeeeekExplorer/nano-vllm) -## Support & Reporting Issues +This project was adapted from [nano-vllm](https://github.com/GeeeekExplorer/nano-vllm). + +## Support & Reporting Issues -We welcome issues and contributions! Please use the GitHub Issues page to report bugs or request features: https://github.com/ROCm/ATOM/issues +We welcome issues and contributions! Please use the GitHub Issues page to report bugs or request features: https://github.com/ROCm/ATOM/issues diff --git a/docs/architecture_guide.md b/docs/architecture_guide.md new file mode 100644 index 000000000..e51064db2 --- /dev/null +++ b/docs/architecture_guide.md @@ -0,0 +1,240 @@ +# ATOM Architecture Guide + +> **Quick Reference** +> +> | Class | Import | Purpose | +> |-------|--------|---------| +> | `LLMEngine` | `from atom.model_engine.llm_engine import LLMEngine` | User-facing inference API | +> | `InputOutputProcessor` | `from atom.model_engine.llm_engine import InputOutputProcessor` | Tokenize/detokenize, TTFT/TPOT stats | +> | `CoreManager` | `from atom.model_engine.engine_core_mgr import CoreManager` | Multi-process orchestration via ZMQ | +> | `EngineCore` | `from atom.model_engine.engine_core import EngineCore` | Per-process engine loop | +> | `DPEngineCoreProc` | `from atom.model_engine.engine_core import DPEngineCoreProc` | Data-parallel engine core variant | +> | `ModelRunner` | `from atom.model_engine.model_runner import ModelRunner` | Per-GPU model execution | +> | `Scheduler` | `from atom.model_engine.scheduler import Scheduler` | Prefill-first request scheduling | +> | `BlockManager` | `from atom.model_engine.block_manager import BlockManager` | KV cache block allocation | +> | `Sequence` | `from atom.model_engine.sequence import Sequence` | Request state and token tracking | +> | `ForwardContext` | `from atom.utils.forward_context import ForwardContext` | Global forward pass metadata | +> | `Config` | `from atom.config import Config` | Master configuration dataclass | + +--- + +## 1. System Overview + +ATOM (AiTer Optimized Model) is AMD's lightweight LLM inference engine, inspired by vLLM's architecture and built on the [AITER](https://github.com/ROCm/aiter) kernel library for ROCm/HIP GPUs. + +Key design principles: + +- **Multi-process architecture** -- each engine core runs in its own process, with ZMQ-based IPC connecting the user-facing API to one or more GPU workers. +- **AITER-native execution** -- model forward passes use AITER's optimized attention, MoE, sampling, and communication kernels rather than generic PyTorch operators. +- **CUDA graph acceleration** -- decode batches are captured into CUDA graphs for replay, eliminating per-step kernel launch overhead. +- **Prefill-first scheduling** -- the scheduler prioritizes prompt prefills before decode steps, following vLLM's continuous batching strategy. +- **Speculative decoding** -- optional EAGLE/MTP (Multi-Token Prediction) draft models propose tokens that are verified via rejection sampling. + +--- + +## 2. Component Architecture + +``` +LLMEngine (user-facing API) +├── InputOutputProcessor (tokenize/detokenize, TTFT/TPOT stats) +├── CoreManager (multi-process orchestration via ZMQ) +│ └── EngineCore (one per DP rank, runs in its own process) +│ ├── ModelRunner (per-GPU execution via AsyncIOProcManager) +│ │ ├── Model (Qwen3, Llama, DeepSeek, Mixtral, etc.) +│ │ ├── Sampler / RejectionSampler +│ │ └── EagleProposer (optional MTP draft) +│ └── Scheduler +│ └── BlockManager (KV cache block management) +└── Config (master configuration) +``` + +**Supported model architectures** (registered in `support_model_arch_dict`, a module-level dict in `model_runner.py`): + +| Architecture key | Implementation | +|---|---| +| `Qwen3ForCausalLM` | `atom.models.qwen3.Qwen3ForCausalLM` | +| `Qwen3MoeForCausalLM` | `atom.models.qwen3_moe.Qwen3MoeForCausalLM` | +| `LlamaForCausalLM` | `atom.models.llama.LlamaForCausalLM` | +| `MixtralForCausalLM` | `atom.models.mixtral.MixtralForCausalLM` | +| `DeepseekV3ForCausalLM` | `atom.models.deepseek_v2.DeepseekV2ForCausalLM` | +| `DeepseekV32ForCausalLM` | `atom.models.deepseek_v2.DeepseekV2ForCausalLM` | +| `GptOssForCausalLM` | `atom.models.gpt_oss.GptOssForCausalLM` | +| `Glm4MoeForCausalLM` | `atom.models.glm4_moe.Glm4MoeForCausalLM` | + +--- + +## 3. Request Lifecycle + +A request flows through the system in ten steps: + +1. **`LLMEngine.add_request()` / `generate()`** -- the user submits a list of prompts (strings or pre-tokenized token IDs) together with `SamplingParams`. + +2. **`InputOutputProcessor.preprocess()`** -- each prompt is tokenized via the HuggingFace tokenizer. A `Sequence` object is created to track the request's state, timing, and block allocation. `arrive_time` is recorded. + +3. **`CoreManager.add_request()`** -- the list of `Sequence` objects is serialized with `pickle` and sent over a ZMQ `ROUTER` socket. When multiple DP ranks are active, requests are distributed round-robin. + +4. **`EngineCore.process_input_sockets()`** -- an I/O thread on the `EngineCore` process receives the serialized data on a ZMQ `DEALER` socket, deserializes it, and places the sequences into the `input_queue`. + +5. **`EngineCore.busy_loop()`** -- the main execution loop pulls from `input_queue` via `pull_and_process_input_queue()`, feeds new sequences into the scheduler, and repeatedly calls `_process_engine_step()` until all work is done. + +6. **`Scheduler.schedule()`** -- implements prefill-first scheduling. Waiting sequences are scheduled for prefill if they fit within `max_num_seqs` and `max_num_batched_tokens` and the `BlockManager` can allocate blocks. If no prefills are pending, running sequences are batched for decode. The scheduler returns a `ScheduledBatch` and the corresponding sequence map. + +7. **`ModelRunner.forward()`** -- executes the three-phase forward pass: + - `prepare_model()` -- assembles input IDs (handling deferred output from previous steps), builds attention metadata, and gathers sampling temperatures. + - `run_model()` -- runs the model forward. Prefill and large batches run eagerly; decode batches replay captured CUDA graphs. Returns logits and hidden states. + - `postprocess()` -- samples tokens (or runs rejection sampling for speculative decoding), prepares deferred output via `tokenIDProcessor`, and optionally proposes draft tokens through `EagleProposer`. + +8. **`Scheduler.postprocess()`** -- appends sampled tokens to each `Sequence`, records `first_token_time`, checks stop conditions (EOS, stop token IDs, stop token sequences, `max_tokens`), and moves finished sequences out of the running queue. The `BlockManager` deallocates blocks for finished sequences. + +9. **Output via ZMQ** -- finished sequences are placed on the `output_queue`. A dedicated output thread serializes them and sends them over a ZMQ `PUSH` socket back to the `CoreManager`, which receives them on a `PULL` socket and places them in `outputs_queue`. + +10. **`InputOutputProcessor.postprocess()`** -- detokenizes completed sequences, computes TTFT (Time To First Token) and TPOT (Time Per Output Token), and returns structured output dictionaries. + +--- + +## 4. Forward Context Pattern + +ATOM uses a module-level global `ForwardContext` to pass metadata through CUDA graph boundaries without threading it as function parameters. + +**Core dataclasses** (defined in `atom/utils/forward_context.py`): + +- **`ForwardContext`** -- top-level container holding: + - `attn_metadata` (`AttentionMetaData`) -- cumulative sequence lengths, block tables, slot mappings, and backend-specific metadata. + - `context` (`Context`) -- positions, prefill flag, batch size, graph batch size, draft flag. + - `dp_metadata` (`DPMetadata`) -- cross-DP-rank token counts and cumulative sums. + - `spec_decode_metadata` (`SpecDecodeMetadata`) -- draft token IDs, target/bonus logits indices. + - `kv_cache_data` (`dict[str, KVCacheTensor]`) -- per-layer KV cache tensor references. + +- **`Context`** -- lightweight struct: `positions`, `is_prefill`, `batch_size`, `graph_bs`, `is_draft`. + +- **`DPMetadata`** -- data parallel metadata with `num_tokens_across_dp()` (all-reduce), `max_tokens_across_dp`, and `chunked_sizes()` context manager. + +**Global accessors:** + +| Function | Purpose | +|---|---| +| `set_forward_context(attn_metadata, atom_config, context, ...)` | Set the global context before a forward pass | +| `get_forward_context()` | Retrieve the current context (used by attention backends) | +| `reset_forward_context()` | Clear after forward pass completes | +| `set_kv_cache_data(kv_cache_data)` | Register KV cache tensors at initialization | + +This pattern enables stateless dispatch: attention backends and model operators call `get_forward_context()` to access metadata without requiring it as a function parameter, which is critical for CUDA graph compatibility. + +--- + +## 5. Multi-Process Architecture + +ATOM uses a multi-process design with ZMQ sockets for inter-process communication: + +``` + ┌──────────────────────────────────┐ + │ LLMEngine │ + │ ┌────────────────────────────┐ │ + │ │ CoreManager │ │ + │ │ │ │ + │ │ ROUTER ──────► DEALER │ │ + │ │ (input) (per rank) │ │ + │ │ │ │ + │ │ PULL ◄─────── PUSH │ │ + │ │ (output) (per rank) │ │ + │ └────────────────────────────┘ │ + └──────────────────────────────────┘ + │ ▲ + pickle │ │ pickle + ▼ │ + ┌──────────────────────────────────────┐ + │ EngineCore (Process) │ + │ │ + │ input_queue ──► busy_loop │ + │ │ │ + │ ┌─────────────────▼───────────────┐ │ + │ │ AsyncIOProcManager │ │ + │ │ ┌────────────────────────────┐ │ │ + │ │ │ ModelRunner (TP rank 0) │ │ │ + │ │ │ ModelRunner (TP rank 1) │ │ │ + │ │ │ ... │ │ │ + │ │ └────────────────────────────┘ │ │ + │ └──────────────────────────────────┘ │ + │ │ + │ Scheduler + BlockManager │ + └──────────────────────────────────────┘ +``` + +**Socket types:** + +| Socket | Type | Direction | Purpose | +|---|---|---|---| +| Input | `ROUTER` (CoreManager) / `DEALER` (EngineCore) | CoreManager -> EngineCore | Send requests and control commands | +| Output | `PUSH` (EngineCore) / `PULL` (CoreManager) | EngineCore -> CoreManager | Return finished sequences and stream outputs | + +**Process hierarchy:** + +- **`CoreManager`** spawns one `EngineCore` process per DP rank using `multiprocessing.Process`. +- Each **`EngineCore`** creates an `AsyncIOProcManager`, which in turn spawns one subprocess per TP rank. +- Each **`ModelRunner`** subprocess initializes AITER's distributed environment via `init_dist_env()` from AITER, setting up NCCL communication across TP ranks. + +**Data-parallel variant** (`DPEngineCoreProc`): + +When `data_parallel_size > 1`, each EngineCore process is a `DPEngineCoreProc` that synchronizes with other DP ranks via `torch.distributed.all_reduce` on a Gloo process group. The `busy_loop()` override ensures all DP ranks stay in lockstep: if one rank has a prefill batch while another does not, the idle rank executes a dummy prefill (`dummy_prefill_execution()`) to keep NCCL collectives synchronized. + +--- + +## 6. Sequence Lifecycle + +The `Sequence` class (in `atom/model_engine/sequence.py`) is the central data structure tracking a single request through the engine. + +**Key fields:** + +| Field | Type | Purpose | +|---|---|---| +| `id` | `int` | Auto-incrementing unique identifier | +| `token_ids` | `list[int]` | Full token sequence (prompt + generated) | +| `num_prompt_tokens` | `int` | Length of the original prompt | +| `num_tokens` | `int` (property) | Total length including generated tokens | +| `block_table` | `list[int]` | KV cache block IDs allocated to this sequence | +| `status` | `SequenceStatus` | Current lifecycle state | +| `type` | `SequenceType` | Current execution type | +| `temperature` | `float` | Sampling temperature | +| `max_tokens` | `int` | Maximum completion length | +| `arrive_time` | `float` | Timestamp when request entered the system | +| `first_token_time` | `float` | Timestamp of first generated token (for TTFT) | +| `leave_time` | `float` | Timestamp when request finished (for TPOT) | +| `spec_token_ids` | `list[int]` | Speculative/draft token IDs for MTP | +| `stream_callback` | `Callable` | Optional per-token streaming callback | + +**Status transitions:** + +``` +WAITING ──(scheduled for prefill)──► RUNNING ──(stop condition met)──► FINISHED + ▲ │ + └────────(preempted by scheduler)────┘ +``` + +- `SequenceStatus.WAITING` -- queued in the scheduler's waiting deque, awaiting block allocation. +- `SequenceStatus.RUNNING` -- actively being processed (prefill or decode). +- `SequenceStatus.FINISHED` -- stop condition met (EOS, stop token, stop sequence, or `max_tokens`). Blocks are deallocated. +- `SequenceStatus.EXIT_ENGINE` -- sentinel status used to signal engine shutdown. + +**Execution types:** + +- `SequenceType.DUMMY` -- initial state before scheduling. +- `SequenceType.PREFILL` -- prompt processing phase (all prompt tokens in one batch). +- `SequenceType.DECODE` -- autoregressive token generation (one or more tokens per step with MTP). + +--- + +## Source Files + +| File | Description | +|------|-------------| +| `atom/model_engine/llm_engine.py` | `LLMEngine` user-facing API, `InputOutputProcessor` for tokenization/detokenization and TTFT/TPOT statistics | +| `atom/model_engine/engine_core.py` | `EngineCore` main execution loop, `DPEngineCoreProc` data-parallel variant, `EngineCoreRequestType` message protocol | +| `atom/model_engine/engine_core_mgr.py` | `CoreManager` ZMQ orchestration, process launching, round-robin DP dispatch | +| `atom/model_engine/model_runner.py` | `ModelRunner` per-GPU execution (model loading, CUDA graph capture, forward pass), `tokenIDProcessor` deferred output handling | +| `atom/model_engine/scheduler.py` | `Scheduler` prefill-first scheduling, `ScheduledBatch` batch descriptor, `ScheduledBatchOutput` forward results | +| `atom/model_engine/sequence.py` | `Sequence` request state, `SequenceStatus` and `SequenceType` enums | +| `atom/model_engine/block_manager.py` | `BlockManager` KV cache block allocation with optional prefix caching | +| `atom/model_engine/request.py` | `RequestOutput` dataclass for streaming callbacks | +| `atom/model_engine/async_proc.py` | `AsyncIOProcManager` and `AsyncIOProc` for spawning and managing ModelRunner subprocesses | +| `atom/utils/forward_context.py` | `ForwardContext`, `Context`, `DPMetadata`, `SpecDecodeMetadata`, `AttentionMetaData` dataclasses and global accessors | +| `atom/config.py` | `Config` master configuration, `ParallelConfig`, `CompilationConfig`, `QuantizationConfig`, `SpeculativeConfig`, `KVCacheTensor` | diff --git a/docs/compilation_cudagraph_guide.md b/docs/compilation_cudagraph_guide.md new file mode 100644 index 000000000..bbb381f52 --- /dev/null +++ b/docs/compilation_cudagraph_guide.md @@ -0,0 +1,515 @@ +# ATOM Compilation & CUDA Graphs Guide + +> **Quick Reference** +> +> | Concept | Key Class / Enum | Import | +> |---------|-----------------|--------| +> | Compilation Levels | `CompilationLevel` | `from atom.config import CompilationLevel` | +> | Compilation Config | `CompilationConfig` | `from atom.config import CompilationConfig` | +> | CUDA Graph Modes | `CUDAGraphMode` | `from atom.config import CUDAGraphMode` | +> | CUDA Graph Wrapper | `CUDAGraphWrapper` | `from atom.utils.cuda_graph import CUDAGraphWrapper` | +> | Forward Context | `ForwardContext` | `from atom.utils.forward_context import ForwardContext` | +> | Compiler Backend | `VllmBackend` | `from atom.utils.backends import VllmBackend` | +> | Compiler Manager | `CompilerManager` | `from atom.utils.backends import CompilerManager` | +> | Compiler Interface | `CompilerInterface` | `from atom.utils.compiler_inferface import CompilerInterface` | +> | Inductor Adaptor | `InductorAdaptor` | `from atom.utils.compiler_inferface import InductorAdaptor` | +> | Piecewise Backend | `PiecewiseBackend` | `from atom.utils.cuda_piecewise_backend import PiecewiseBackend` | +> | Compile Decorator | `@support_torch_compile` | `from atom.utils.decorators import support_torch_compile` | +> | Custom Op Registration | `direct_register_custom_op` | `from atom.utils.custom_register import direct_register_custom_op` | +> +> **Compilation Levels at a Glance** +> +> | Level | Name | Behavior | +> |-------|------|----------| +> | 0 | `NO_COMPILATION` | Pure eager execution, no `torch.compile` | +> | 1 | `DYNAMO_AS_IS` | `torch.compile` with `backend="eager"` | +> | 2 | `DYNAMO_ONCE` | `torch.compile` with Inductor | +> | 3 | `PIECEWISE` | Piecewise compilation with CUDA graph capture (production default) | +> +> **CUDA Graph Modes at a Glance** +> +> | Mode | Value | Behavior | +> |------|-------|----------| +> | `NONE` | `0` | No graph capture | +> | `PIECEWISE` | `1` | Per-subgraph capture (default for level 3) | +> | `FULL` | `2` | Whole-model capture | +> | `FULL_DECODE_ONLY` | `(FULL, NONE)` | Full for decode, none for mixed batches | +> | `FULL_AND_PIECEWISE` | `(FULL, PIECEWISE)` | Full for decode, piecewise for prefill | + +--- + +## 1. Compilation Levels + +ATOM provides four compilation levels via the `CompilationLevel` class in `atom/config.py`. The level is set through `CompilationConfig.level` and controls how `torch.compile` is applied to the model. + +### Level 0 -- NO_COMPILATION + +No `torch.compile` is applied. The model runs in pure eager mode. This is the simplest mode and is useful for debugging or when using models that are incompatible with `torch.compile`. + +When `level=0`, the `@support_torch_compile` decorator sets `self.do_not_compile = True` and the model's `__call__` method bypasses compilation entirely, calling `self.forward()` directly. + +### Level 1 -- DYNAMO_AS_IS + +Uses `torch.compile` with `backend="eager"` and `fullgraph=True`. This runs Dynamo's bytecode analysis and graph capture but does not apply any compiler optimizations. It is useful as a quick check to verify that a model is compatible with Dynamo's tracing. + +Like level 0, `DYNAMO_AS_IS` causes the decorator to set `self.do_not_compile = True`, since the model runner (rather than the decorator) handles the compilation at this level. + +### Level 2 -- DYNAMO_ONCE + +Uses `torch.compile` with the Inductor backend. The model graph is traced by Dynamo and compiled once through Inductor for optimized GPU kernel generation. The `@support_torch_compile` decorator's custom dispatcher is activated when `compilation_level >= DYNAMO_ONCE`, allowing compiled bytecode to be dispatched directly after the first compilation without repeated guard evaluation. + +### Level 3 -- PIECEWISE (Production Default) + +The most advanced level. When `Config.__post_init__` detects `level == PIECEWISE`, it: + +1. Calls `CompilationConfig.set_splitting_ops_for_v1()` to configure the splitting operations (default: `["aiter.unified_attention_with_output", "aiter.mla_attention"]`). +2. Calls `Config._set_cudagraph_sizes()` to compute the graph batch sizes. +3. Sets `cudagraph_mode = CUDAGraphMode.PIECEWISE`. +4. Calls `CompilationConfig.init_with_cudagraph_sizes()` to finalize compile sizes. + +The `VllmBackend` is then used as the `torch.compile` backend. It splits the model graph into subgraphs at the splitting operations and compiles each subgraph independently via `PiecewiseBackend`. + +--- + +## 2. CUDA Graph Modes + +The `CUDAGraphMode` enum in `atom/config.py` controls how CUDA graphs are captured and replayed. CUDA graphs record a sequence of GPU operations and replay them with minimal CPU overhead, which is critical for low-latency decode steps. + +### NONE (value: 0) + +No CUDA graph capture or replay. Every forward pass launches kernels individually. This mode is used during profiling, warmup, or when CUDA graphs are not supported. + +### PIECEWISE (value: 1) + +The default mode for level 3 compilation. CUDA graphs are captured per subgraph (one for each piecewise-compiled region). Attention operations, which are split out by `splitting_ops`, run outside CUDA graphs because they may need dynamic metadata that changes between steps. + +The `CUDAGraphWrapper` class wraps each subgraph with `runtime_mode=CUDAGraphMode.PIECEWISE` for capture and replay. + +### FULL (value: 2) + +The entire model forward pass is captured as a single CUDA graph. This is suitable for small models or workloads with small, uniform batch sizes. Not all attention backends support full CUDA graph capture. + +### FULL_DECODE_ONLY (value: (FULL, NONE)) + +A tuple mode that applies different strategies to different batch types: +- **Decode batches**: Captured with full CUDA graphs. +- **Mixed prefill-decode batches**: Run without CUDA graphs. + +This is useful for prefill/decode disaggregated (P/D) setups where decode latency matters more than prefill performance. + +### FULL_AND_PIECEWISE (value: (FULL, PIECEWISE)) + +A tuple mode combining both strategies: +- **Decode batches**: Captured with full CUDA graphs. +- **Prefill and mixed batches**: Captured with piecewise CUDA graphs. + +This is described in the code as "the most performant mode for most models." + +### Helper Methods + +The `CUDAGraphMode` enum provides several helper methods for runtime dispatch: + +| Method | Returns | Purpose | +|--------|---------|---------| +| `decode_mode()` | `CUDAGraphMode` | Returns the mode to use for decode batches. For tuple modes, returns the first element. | +| `mixed_mode()` | `CUDAGraphMode` | Returns the mode to use for mixed batches. For tuple modes, returns the second element. | +| `separate_routine()` | `bool` | Returns `True` if the mode is a tuple (different strategies for decode vs. mixed). | +| `has_full_cudagraphs()` | `bool` | Returns `True` if any part of the mode uses `FULL` capture. | +| `requires_piecewise_compilation()` | `bool` | Returns `True` if either decode or mixed mode uses `PIECEWISE`. | +| `max_cudagraph_mode()` | `CUDAGraphMode` | Returns the highest-valued mode across both decode and mixed modes. | + +--- + +## 3. CUDA Graph Capture + +CUDA graph capture is handled by `ModelRunner.capture_cudagraph()` in `atom/model_engine/model_runner.py`. This method is called at startup (under `@torch.inference_mode()`) to pre-capture graphs for a set of batch sizes. + +### Capture Flow + +``` +capture_cudagraph() + | + +-- Determine graph_bs list + | |-- If cudagraph_capture_sizes is set: use directly + | |-- If cuda_graph_sizes has 1 value N: [1, 2, 4, 8, 16, 32, ..., N] + | +-- If cuda_graph_sizes has >1 values: use the provided list + | + +-- Sort graph_bs in descending order (largest batch first) + | + +-- Assert max batch size <= max_num_seqs + | + +-- Initialize graph storage: self.graphs = dict() + | + +-- For each batch size bs (with progress bar on rank 0): + | | + | +-- Compute max_q_len (= mtp_k + 1 if MTP drafter, else 1) + | +-- Compute num_tokens = bs * max_q_len + | +-- Prepare cu_seqlens_q, positions + | +-- Build attn_metadata and context via attn_metadata_builder + | +-- Handle DP padding via get_dp_padding() + | +-- Set forward context (set_forward_context) + | +-- Warmup run: model(input_ids[:num_tokens], positions[:num_tokens]) + | +-- Capture: torch.cuda.graph(graph, self.graph_pool, stream=gc.stream) + | +-- Share graph_pool across captures (set on first capture) + | +-- Store: self.graphs[(bs, max_q_len)] = graph + | +-- torch.cuda.synchronize() + | + +-- Sort graph_bs back to ascending order + +-- Return (elapsed_time, graph_bs) +``` + +### Graph Keying + +Each captured graph is stored in a dictionary keyed by a `(graph_bs, max_q_len)` tuple: + +```python +self.graphs: dict[tuple[int, int], torch.cuda.CUDAGraph] = dict() +``` + +- `graph_bs`: The padded batch size used during capture. +- `max_q_len`: The maximum query length per sequence. For standard decode, this is `1`. For MTP (Multi-Token Prediction) speculative decoding, this is `mtp_k + 1`. + +### Graph Pool Sharing + +The first captured graph creates a CUDA memory pool via `graph.pool()`. All subsequent captures share this pool through the `self.graph_pool` parameter, enabling memory reuse across different batch sizes. + +```python +if self.graph_pool is None: + self.graph_pool = graph.pool() +``` + +### Default Capture Sizes + +When `cuda_graph_sizes` has a single value (e.g., `[512]`, the default), the capture sizes follow this pattern: + +```python +[1, 2, 4, 8] + [i for i in range(16, cuda_graph_sizes[0] + 1, 16)] +# Example with default 512: +# [1, 2, 4, 8, 16, 32, 48, 64, ..., 496, 512] +``` + +### Graph Replay in run_model() + +During inference, `ModelRunner.run_model()` decides whether to use eager execution or graph replay: + +```python +def run_model(self, input_ids): + forward_context = get_forward_context() + context = forward_context.context + bs = context.batch_size + is_prefill = context.is_prefill + positions = context.positions + + if is_prefill or self.enforce_eager or bs > self.graph_bs[-1]: + # Eager path: prefills, enforce_eager mode, or oversized batches + hidden_states = self.model(input_ids, positions) + else: + # Graph replay path: decode batches within captured range + graph_bs = context.graph_bs + max_q_len = forward_context.attn_metadata.max_seqlen_q + graph_key = (graph_bs, max_q_len) + self.graphs[graph_key].replay() + num_tokens = context.batch_size * max_q_len + hidden_states = self.forward_vars["outputs"][:num_tokens] + + return self.model.compute_logits(hidden_states), hidden_states +``` + +Key decisions: +- **Prefill**: Always eager (variable sequence lengths make CUDA graphs impractical). +- **Decode with bs <= max captured size**: Replay the pre-captured graph. +- **Decode with bs > max captured size**: Fall back to eager execution. +- **enforce_eager=True**: Always eager, regardless of batch size. + +--- + +## 4. Piecewise Compilation + +Piecewise compilation splits the model's computation graph at specified operations and compiles each subgraph independently. This enables CUDA graph capture for the compilable parts while leaving incompatible operations (primarily attention) to run eagerly. + +### Splitting Operations + +The `splitting_ops` field in `CompilationConfig` defines which operations split the graph. When `set_splitting_ops_for_v1()` is called (automatically at level 3), the default splitting ops are: + +```python +["aiter.unified_attention_with_output", "aiter.mla_attention"] +``` + +These attention operations are split out because: +1. They require dynamic metadata (sequence lengths, block tables) that changes per step. +2. Some attention backends are not compatible with CUDA graph capture. +3. Attention kernels are already highly optimized, so Inductor compilation provides minimal additional benefit. + +### Compilation Pipeline + +The `VllmBackend.__call__` method orchestrates the piecewise compilation: + +1. **Graph splitting**: `split_graph()` divides the traced model graph at the splitting operations into a sequence of `SplitItem` objects, each containing a subgraph. + +2. **Submodule identification**: Subgraphs that are *not* splitting operations are identified as candidates for compilation. + +3. **Dynamic-shape compilation**: `PiecewiseCompileInterpreter` runs the split graph with fake inputs and compiles each non-splitting subgraph via `CompilerManager.compile()` for a general (dynamic) shape. + +4. **Backend creation**: For each compiled subgraph, a `PiecewiseBackend` instance is created. It holds: + - `compiled_graph_for_general_shape`: The Inductor-compiled graph for dynamic shapes. + - `concrete_size_entries`: A dictionary mapping specific runtime shapes to `ConcreteSizeEntry` objects for shape-specialized compilation. + +5. **Runtime dispatch**: When `PiecewiseBackend.__call__` is invoked: + - On the first run, it uses the general-shape compiled graph. + - For subsequent runs, if the runtime shape is in `compile_sizes`, it lazily compiles a shape-specialized version via `CompilerManager.compile()` and caches it. + - For shapes not in `compile_sizes`, it falls back to the general-shape compiled graph. + +### Cache Management + +The `CompilerManager` caches compiled graphs using a key of `(runtime_shape, graph_index, backend_name)`. The cache is stored in a Python file (`vllm_compile_cache.py`) at the local cache directory (`~/.cache/atom/torch_compile_cache//rank_//`). + +On subsequent runs with the same model and configuration, compiled graphs are loaded from the cache, bypassing Inductor compilation entirely. + +--- + +## 5. Forward Context & Stateless Dispatch + +The `ForwardContext` dataclass in `atom/utils/forward_context.py` provides a module-level global mechanism for passing metadata to layers during the forward pass. This is critical for CUDA graphs because captured graphs cannot accept new arguments -- all dynamic metadata must be accessible through a side channel. + +### ForwardContext Fields + +| Field | Type | Purpose | +|-------|------|---------| +| `no_compile_layers` | `dict[int, Any]` | Layers that should skip compilation (from `static_forward_context`) | +| `attn_metadata` | `AttentionMetaData` or `dict` | Attention-specific metadata (sequence lengths, block tables, etc.) | +| `kv_cache_data` | `dict[str, KVCacheTensor]` | KV cache tensors for each layer | +| `context` | `Context` | Basic forward pass context (positions, is_prefill, batch_size, graph_bs) | +| `dp_metadata` | `DPMetadata` | Data-parallel metadata (token counts across DP ranks) | +| `spec_decode_metadata` | `SpecDecodeMetadata` | Speculative decoding metadata (draft tokens, logits indices) | + +### Lifecycle + +The forward context follows a set-use-reset lifecycle: + +1. **Set**: Before each forward pass, `set_forward_context()` is called with attention metadata, the ATOM config, a `Context` object, and optional DP/speculative decoding metadata. + +2. **Access**: During the forward pass, any layer can call `get_forward_context()` to retrieve the current metadata without needing it passed as a function argument. This is used by both eager execution and CUDA graph replay paths. + +3. **Reset**: After the forward pass, `reset_forward_context()` replaces the global context with an empty `ForwardContext()`. + +### Context Dataclass + +The `Context` object carries the most frequently accessed per-step state: + +```python +@dataclass +class Context: + positions: torch.Tensor # Token position IDs + is_prefill: bool = False # Whether this is a prefill step + batch_size: int = 0 # Number of sequences in the batch + graph_bs: int = 0 # Padded batch size for graph lookup + is_draft: bool = False # Whether this is a draft model forward +``` + +The `graph_bs` field is particularly important for CUDA graph dispatch: it holds the padded batch size that maps to a pre-captured graph key. + +### Integration with CUDA Graphs + +For ModelRunner's direct CUDA graph path (non-piecewise), the forward context is set before `run_model()` via `set_forward_context()`, and `run_model()` reads `context.graph_bs` and `attn_metadata.max_seqlen_q` to look up the correct pre-captured graph. + +For the piecewise path, `CUDAGraphWrapper` (in `atom/utils/cuda_graph.py`) expects `batch_descriptor` and `cudagraph_runtime_mode` fields on the forward context to decide whether to capture, replay, or run eagerly: + +```python +forward_context = get_forward_context() +batch_descriptor = forward_context.batch_descriptor +cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode +``` + +> **Note:** The piecewise `CUDAGraphWrapper` integration is under development. The `batch_descriptor` and `cudagraph_runtime_mode` fields are expected by `CUDAGraphWrapper.__call__()` but are not currently defined on the `ForwardContext` dataclass. The per-subgraph wrapping in `backends.py` is also currently commented out. The direct CUDA graph path in `ModelRunner` is the active production path. + +--- + +## 6. Compiler Backend + +### CompilerManager + +`CompilerManager` in `atom/utils/backends.py` manages the full compilation lifecycle: + +- **Initialization**: Creates a `CompilerInterface` via `make_compiler()`. Uses `InductorStandaloneAdaptor` for PyTorch 2.8+ or `InductorAdaptor` for earlier versions. +- **Caching**: Maintains a dictionary mapping `(runtime_shape, graph_index, backend_name)` to compiler-specific handles. Caches are serialized to a Python file using `pprint`. +- **Compile-or-load**: On each call to `compile()`, first attempts `load()` from the cache. On miss, delegates to the compiler and stores the result. + +### CompilerInterface + +`CompilerInterface` in `atom/utils/compiler_inferface.py` (note the typo in the filename) defines the abstract interface that all compiler backends must implement: + +| Method | Purpose | +|--------|---------| +| `initialize_cache(cache_dir, disable_cache, prefix)` | Set up cache directories for the compiler | +| `compute_hash(vllm_config)` | Generate a hash of compiler-specific state for cache invalidation | +| `compile(graph, example_inputs, compiler_config, runtime_shape, key)` | Compile a graph, returning `(compiled_callable, handle)` | +| `load(handle, graph, example_inputs, graph_index, runtime_shape)` | Load a previously compiled graph from the handle | + +### InductorAdaptor + +The default compiler for PyTorch < 2.8. Uses `torch._inductor.compile_fx.compile_fx` and monkey-patches several internal functions to: + +- Extract the compilation hash for caching. +- Provide a dummy shape environment (`AlwaysHitShapeEnv`) so Inductor cache lookups succeed outside of Dynamo's tracing context. +- Force caching of graphs that Inductor would normally refuse to cache. + +When `runtime_shape` is an integer (specific batch size), it enables `max_autotune` and `coordinate_descent_tuning` for Triton kernel parameter optimization. + +### InductorStandaloneAdaptor + +The preferred compiler for PyTorch 2.8+. Uses `torch._inductor.standalone_compile` which provides a cleaner interface without the monkey-patching required by `InductorAdaptor`. Compiled artifacts are saved to disk in "unpacked" format and can be loaded directly. + +### VllmBackend + +`VllmBackend` in `atom/utils/backends.py` serves as the `torch.compile` backend for level 3 (piecewise) compilation. When Dynamo calls it: + +1. Computes a cache directory hash from config, traced files, and compiler state. +2. Splits the graph at `splitting_ops` using `split_graph()`. +3. Runs `PiecewiseCompileInterpreter` to compile each non-splitting subgraph. +4. Saves the computation graph to `computation_graph.py` for debugging. +5. Returns the stitching graph module (`split_gm`) as the callable. + +If `cudagraph_copy_inputs` is `True`, it wraps the callable to copy input tensors into static buffers before each call, ensuring CUDA graph input address stability. + +### @support_torch_compile Decorator + +The `@support_torch_compile` decorator in `atom/utils/decorators.py` augments a model class to support `torch.compile`: + +1. **Class modification**: Adds `TorchCompileWrapperWithCustomDispatcher` as a base class and overrides `__init__` and `__call__`. + +2. **Dynamic shape marking**: On the first compilation, it inspects the `forward` method signature, identifies `torch.Tensor` arguments, and calls `torch._dynamo.mark_dynamic()` to mark their batch dimensions as dynamic. + +3. **Custom dispatch**: After the first compilation, if `use_custom_dispatcher` is True (levels >= 2), subsequent calls bypass Dynamo's guard mechanism and dispatch directly to the compiled bytecode via `dispatch_to_code(0)`. + +4. **Safety check**: The bytecode hook checks for `update` in the compiled code's `co_names`, raising an error if the model modifies `nn.Module` buffers during the forward pass (which would cause silent errors with CUDA graphs). + +### Custom Op Registration + +`direct_register_custom_op()` in `atom/utils/custom_register.py` registers custom operators with PyTorch's `torch.library` system: + +```python +direct_register_custom_op( + op_name="my_op", + op_func=my_kernel, + mutates_args=["output"], + fake_impl=my_fake_impl, +) +``` + +This registers the op under the `"aiter"` library namespace (e.g., `aiter.my_op`), making it visible to Dynamo's tracing. The `fake_impl` is used during tracing to compute output shapes without executing the real kernel. The `dispatch_key` defaults to `"CUDA"` for GPU operations. + +Registered custom ops can be used as `splitting_ops` in piecewise compilation (e.g., `"aiter.unified_attention_with_output"`). + +--- + +## 7. Configuration Options + +All compilation-related configuration fields from `CompilationConfig`: + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `level` | `int` | `0` | Compilation level (0-3). See Section 1. | +| `use_cudagraph` | `bool` | `True` | Whether CUDA graph capture is enabled. | +| `cudagraph_capture_sizes` | `Optional[list[int]]` | `None` | Explicit list of batch sizes to capture. Overrides `cuda_graph_sizes`. | +| `cuda_graph_sizes` | `list[int]` | `[512]` | Controls auto-generated capture sizes. 1 value = generate pattern; >1 values = use directly. | +| `cudagraph_mode` | `Optional[CUDAGraphMode]` | `None` | CUDA graph mode. Set to `PIECEWISE` automatically at level 3. | +| `splitting_ops` | `Optional[list[str]]` | `None` | Operations that split the graph for piecewise compilation. Auto-set at level 3. | +| `cudagraph_copy_inputs` | `bool` | `False` | Copy input tensors to static buffers for CUDA graph address stability. Only effective in `PIECEWISE` mode. | +| `use_inductor` | `bool` | `True` | Whether to use the Inductor compiler backend. | +| `compile_sizes` | `Optional[list[Union[int, str]]]` | `None` | Specific sizes to compile with Inductor. Supports `"cudagraph_capture_sizes"` string. | +| `inductor_compile_config` | `dict` | `{}` | Additional Inductor configuration (e.g., `max_autotune`). | +| `debug_dump_path` | `str` | `""` | Path to dump debug information (traced graphs, decompiled code). | +| `cache_dir` | `str` | `""` | Custom cache directory. Auto-generated if empty (`~/.cache/atom/torch_compile_cache//`). | + +Related fields on `Config`: + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `enforce_eager` | `bool` | `False` | Force eager execution, skip all compilation and CUDA graphs. | +| `graph_bs` | `Optional[list[int]]` | `None` | Final list of batch sizes for CUDA graph capture (computed from `CompilationConfig`). | +| `compilation_config` | `CompilationConfig` | `CompilationConfig()` | The compilation configuration dataclass. | + +--- + +## 8. Decision Tree + +Use this decision tree to select the right compilation level and CUDA graph mode for your workload: + +``` +Is the model supported by torch.compile? +| ++-- No --> Level 0 (NO_COMPILATION) +| enforce_eager=True +| ++-- Yes + | + +-- Debugging / profiling? + | | + | +-- Yes --> Level 0 (NO_COMPILATION) + | + +-- Quick compatibility check? + | | + | +-- Yes --> Level 1 (DYNAMO_AS_IS) + | + +-- Want Inductor optimization without CUDA graphs? + | | + | +-- Yes --> Level 2 (DYNAMO_ONCE) + | + +-- Production deployment + | + +-- Level 3 (PIECEWISE) [recommended] + | + +-- Standard serving --> cudagraph_mode=PIECEWISE (default) + | + +-- Small model / uniform batches --> cudagraph_mode=FULL + | + +-- P/D disaggregated (decode instance) --> cudagraph_mode=FULL_DECODE_ONLY + | + +-- Maximum performance --> cudagraph_mode=FULL_AND_PIECEWISE +``` + +### Common Configurations + +**Default production setup** (level 3, piecewise CUDA graphs): +```python +CompilationConfig(level=3) +# Automatically sets: +# splitting_ops = ["aiter.unified_attention_with_output", "aiter.mla_attention"] +# cudagraph_mode = CUDAGraphMode.PIECEWISE +# cuda_graph_sizes = [512] +# graph_bs = [1, 2, 4, 8, 16, 32, ..., 512] +``` + +**Custom capture sizes**: +```python +CompilationConfig(level=3, cudagraph_capture_sizes=[1, 2, 4, 8]) +``` + +**Debugging with full eager execution**: +```python +Config(model="...", enforce_eager=True) +# or +CompilationConfig(level=0) +``` + +**Inductor with debug dump**: +```python +CompilationConfig(level=3, debug_dump_path="/tmp/atom_debug") +# Dumps traced graphs and decompiled code to /tmp/atom_debug/rank_0/ +``` + +--- + +## Source Files + +| File | Description | +|------|-------------| +| `atom/config.py` | `CompilationLevel`, `CompilationConfig`, `CUDAGraphMode`, `Config.__post_init__` (compilation setup) | +| `atom/utils/cuda_graph.py` | `CUDAGraphEntry`, `CUDAGraphOptions`, `CUDAGraphWrapper`, `BatchDescriptor` | +| `atom/utils/backends.py` | `CompilerManager`, `VllmBackend`, `SplitItem`, `split_graph()`, `PiecewiseCompileInterpreter` | +| `atom/utils/forward_context.py` | `ForwardContext`, `Context`, `AttentionMetaData`, `DPMetadata`, `set_forward_context()`, `get_forward_context()` | +| `atom/utils/compiler_inferface.py` | `CompilerInterface`, `InductorAdaptor`, `InductorStandaloneAdaptor`, `AlwaysHitShapeEnv` | +| `atom/utils/cuda_piecewise_backend.py` | `PiecewiseBackend`, `ConcreteSizeEntry` | +| `atom/utils/decorators.py` | `@support_torch_compile`, `TorchCompileWrapperWithCustomDispatcher`, `start_monitoring_torch_compile` | +| `atom/utils/custom_register.py` | `direct_register_custom_op()`, `aiter_lib` (Library instance) | +| `atom/model_engine/model_runner.py` | `ModelRunner.capture_cudagraph()`, `ModelRunner.run_model()` | diff --git a/docs/configuration_guide.md b/docs/configuration_guide.md new file mode 100644 index 000000000..cae23a96e --- /dev/null +++ b/docs/configuration_guide.md @@ -0,0 +1,356 @@ +# ATOM Configuration Guide + +ATOM (AiTer Optimized Model) is AMD's lightweight LLM inference engine built on +[AITER](https://github.com/ROCm/aiter) kernels for ROCm/HIP GPUs. This guide +documents every configuration class, CLI flag, and environment variable that +controls ATOM's runtime behaviour. + +--- + +## Quick Reference + +| Config Class | Primary Purpose | +|---|---| +| `Config` | Master dataclass -- model path, memory, TP size, scheduler limits, KV cache, profiler, and references to all sub-configs | +| `CompilationConfig` | Compilation level (0-3), CUDA graph capture sizes, piecewise splitting ops, inductor settings | +| `CompilationLevel` | Integer constants for the four compilation levels | +| `CUDAGraphMode` | Enum controlling how CUDA graphs are captured (none / piecewise / full / hybrid) | +| `QuantizationConfig` | Quantization type, dtype, dynamic flag, method, excluded layers | +| `ParallelConfig` | Data-parallel size, rank, master IP/port | +| `SpeculativeConfig` | Speculative decoding method, draft model, number of speculative tokens | +| `KVCacheConfig` / `KVCacheTensor` | Per-layer KV cache tensor descriptors (k/v caches and scales) | +| `SamplingParams` | Temperature, max tokens, stop strings, ignore-EOS flag | +| `EngineArgs` | CLI argument parser that builds a `Config` for `LLMEngine` | + +--- + +## 1. Master Configuration (`Config`) + +Defined in `atom/config.py`. The root dataclass that the engine consumes. + +| Field | Type | Default | Description | +|---|---|---|---| +| `model` | `str` | *(required)* | HuggingFace model name or local path | +| `trust_remote_code` | `bool` | `False` | Trust remote code when loading the model from HuggingFace | +| `max_num_batched_tokens` | `int` | `16384` | Maximum number of tokens batched together per scheduler step | +| `scheduler_delay_factor` | `float` | `0.0` | Multiplicative delay (factor x previous prompt latency) before scheduling the next prompt | +| `max_num_seqs` | `int` | `512` | Maximum number of sequences batched together | +| `max_model_len` | `int \| None` | `None` | Maximum context length; defaults to `hf_config.max_position_embeddings` (capped by it when set) | +| `gpu_memory_utilization` | `float` | `0.9` | Fraction of GPU memory available for KV cache and weights (0.0 -- 1.0) | +| `tensor_parallel_size` | `int` | `1` | Number of tensor-parallel GPUs (1 -- 8) | +| `enforce_eager` | `bool` | `False` | Disable compilation and CUDA graphs; run in eager mode | +| `parallel_config` | `ParallelConfig` | `ParallelConfig()` | Data-parallel configuration (see Section 4) | +| `kv_cache_block_size` | `int` | `16` | Block size for paged KV cache; must be a multiple of 16 or exactly 1 | +| `num_kvcache_blocks` | `int` | `-1` | Number of KV cache blocks (`-1` = auto) | +| `kv_cache_dtype` | `str` | `"bf16"` | KV cache data type (`"bf16"` or `"fp8"`) | +| `enable_prefix_caching` | `bool` | `False` | Enable prefix caching to reuse KV blocks across requests sharing the same prefix | +| `port` | `int` | `8006` | Engine internal communication port | +| `torch_profiler_dir` | `str \| None` | `os.getenv("ATOM_TORCH_PROFILER_DIR", None)` | Directory for saving PyTorch profiler traces; creates the directory if it does not exist | +| `compilation_config` | `CompilationConfig` | `CompilationConfig()` | Compilation and CUDA graph settings (see Section 2) | +| `quant_config` | `QuantizationConfig` | `QuantizationConfig()` | Quantization settings; auto-detected from HuggingFace config during `__post_init__` (see Section 3) | +| `asyncio_mode` | `bool` | `False` | Enable asyncio-based engine loop | +| `load_dummy` | `bool` | `False` | Skip loading model weights (for benchmarking / testing) | +| `enable_expert_parallel` | `bool` | `False` | Enable Expert Parallelism for MoE models | +| `master_addr` | `str` | `"127.0.0.1"` | Master address for distributed communication | +| `graph_bs` | `Optional[list[int]]` | `None` | Explicit list of batch sizes for CUDA graph capture; derived from `compilation_config` during init | +| `enable_dp_attention` | `bool` | `False` | Enable data-parallel attention | +| `torch_dtype` | `torch.dtype` | *(computed)* | Inferred from `hf_config.torch_dtype`; falls back to `torch.bfloat16` | +| `speculative_config` | `Optional[SpeculativeConfig]` | `None` | Speculative decoding configuration (see Section 5) | +| `bos_token_id` | `int` | `-1` | Beginning-of-sequence token ID (`-1` = use model default) | +| `eos_token_id` | `int` | `-1` | End-of-sequence token ID (`-1` = use model default) | +| `stop_token_ids` | `list[int]` | `[]` | Additional stop token IDs; populated from `GenerationConfig.eos_token_id` during init | + +**Auto-derived fields** (set in `__post_init__`, not user-supplied): + +| Field | Type | Description | +|---|---|---| +| `hf_config` | `PretrainedConfig` | Loaded automatically via `get_hf_config(model)` | +| `generation_config` | `GenerationConfig` | Loaded automatically via `get_generation_config(model)` | + +--- + +## 2. Compilation Configuration (`CompilationConfig`) + +Defined in `atom/config.py`. Controls torch.compile and CUDA graph behaviour. + +### 2.1 Compilation Levels (`CompilationLevel`) + +| Constant | Value | Description | +|---|---|---| +| `NO_COMPILATION` | `0` | No compilation -- pure eager execution | +| `DYNAMO_AS_IS` | `1` | Use torch.compile / TorchDynamo as-is | +| `DYNAMO_ONCE` | `2` | TorchDynamo with a single compilation pass | +| `PIECEWISE` | `3` | Piecewise compilation with CUDA graph capture (recommended for production) | + +### 2.2 `CompilationConfig` Fields + +| Field | Type | Default | Description | +|---|---|---|---| +| `level` | `int` | `0` | Compilation level (see table above); must be 0 -- 3 | +| `use_cudagraph` | `bool` | `True` | Whether to use CUDA graphs | +| `cudagraph_capture_sizes` | `Optional[list[int]]` | `None` | Explicit list of batch sizes for CUDA graph capture; overrides `cuda_graph_sizes` when set | +| `cuda_graph_sizes` | `list[int]` | `[]` (post-init: `[512]`) | CUDA graph sizing strategy: 1 value generates `[1,2,4,8] + range(16, N+1, 16)`; multiple values used as-is; empty defaults to `[512]` | +| `debug_dump_path` | `str` | `""` | Path to dump debug / compilation information | +| `cache_dir` | `str` | `""` | Directory for compilation caches | +| `use_inductor` | `bool` | `True` | Enable TorchInductor backend | +| `cudagraph_mode` | `Optional[CUDAGraphMode]` | `None` | CUDA graph capture mode (see below); set to `PIECEWISE` automatically at level 3 | +| `splitting_ops` | `Optional[list[str]]` | `None` | Ops that split the graph into sub-graphs for piecewise compilation; auto-populated at level 3 with `["aiter.unified_attention_with_output", "aiter.mla_attention"]` | +| `cudagraph_copy_inputs` | `bool` | `False` | Copy input tensors into internally managed buffers before CUDA graph replay; only effective in PIECEWISE mode | +| `compile_sizes` | `Optional[list[Union[int, str]]]` | `None` | Sizes to compile for inductor; accepts integers and the string `"cudagraph_capture_sizes"` | +| `inductor_compile_config` | `dict` | `{}` | Additional configuration passed to the inductor backend | + +### 2.3 CUDA Graph Mode (`CUDAGraphMode`) + +| Mode | Value | Description | +|---|---|---| +| `NONE` | `0` | No CUDA graph capture | +| `PIECEWISE` | `1` | Piecewise CUDA graphs -- attention ops stay outside the graph for flexibility (default at level 3) | +| `FULL` | `2` | Full CUDA graph capture for all batches; best for small models / short prompts | +| `FULL_DECODE_ONLY` | `(FULL, NONE)` | Full CUDA graphs for decode batches only; mixed prefill-decode runs without graphs (useful in P/D setups) | +| `FULL_AND_PIECEWISE` | `(FULL, PIECEWISE)` | Full graphs for decode, piecewise for prefill/mixed -- most performant mode for most models | + +Helper methods on `CUDAGraphMode`: + +- `decode_mode()` -- returns the mode used for pure decode batches. +- `mixed_mode()` -- returns the mode used for mixed prefill-decode batches. +- `requires_piecewise_compilation()` -- whether the mode needs piecewise compilation. +- `has_full_cudagraphs()` -- whether the mode includes full CUDA graph capture. +- `separate_routine()` -- whether decode and mixed batches use different routines. + +--- + +## 3. Quantization Configuration (`QuantizationConfig`) + +Defined in `atom/config.py`. Extends `dict` so fields are stored and accessed as +dictionary keys (e.g., `config["quant_type"]`). + +### 3.1 `QuantizationConfig` Fields + +| Field | Type | Default | Description | +|---|---|---|---| +| `quant_type` | `QuantType` | `QuantType.No` | Quantization granularity (see below) | +| `quant_dtype` | `torch.dtype` | `torch.bfloat16` | Data type for quantized weights | +| `is_dynamic` | `bool` | `True` | Use dynamic quantization (scales computed at runtime) | +| `quant_name` | `str` | `""` | Human-readable name for the quantization scheme | +| `quant_method` | `Optional[str]` | `None` | Quantization method from HuggingFace config (e.g., `"compressed-tensors"`, `"quark"`) | +| `exclude_layers` | `Optional[list[str]]` | `[]` | Layer names excluded from quantization | + +### 3.2 `QuantType` Values (from AITER) + +| Value | Description | +|---|---| +| `QuantType.No` | No quantization | +| `QuantType.per_Token` | Per-token / per-channel quantization | +| `QuantType.per_1x128` | Block quantization with group size 128 | +| `QuantType.per_1x32` | Block quantization with group size 32 | +| `QuantType.per_128x128` | Large 2D block quantization (remapped to `per_1x128` in MoE kernels) | +| `QuantType.per_Tensor` | Per-tensor quantization | + +### 3.3 Supported Quantization Dtypes + +| Dtype | AITER Key | Notes | +|---|---|---| +| FP8 (E4M3) | `"fp8"` | 8-bit floating point | +| MXFP4 | `"fp4x2"` | Microscaling FP4; forces `QuantType.per_1x32` | +| INT8 | `"i8"` | 8-bit integer | +| INT4 | `"i4x2"` | 4-bit integer (packed) | + +### 3.4 Auto-Detection from HuggingFace (`get_quant_config`) + +During `Config.__post_init__`, ATOM reads `hf_config.quantization_config` and +automatically determines `quant_type`, `quant_dtype`, and `is_dynamic`: + +1. If `quantization_config` is absent, returns `QuantType.No` with `torch_dtype`. +2. If `quant_method == "compressed-tensors"` or channel quantization is detected, sets `per_Token`. +3. If `weight_block_size` or `group_size` is found: group size 128 maps to `per_1x128`, group size 32 maps to `per_1x32`. +4. Otherwise falls back to `per_Tensor`. +5. The dtype is parsed from fields like `dtype`, `weight_dtype`, or `quant_method` looking for `fp8`, `fp4`, `mxfp4`, `int8`, `int4`, or `num_bits`. +6. If `activation_scheme` is `"static"`, `is_dynamic` is set to `False`. +7. Excluded layers are read from the `"ignore"` key (compressed-tensors) or `"exclude"` key (quark). + +--- + +## 4. Parallel Configuration (`ParallelConfig`) + +Defined in `atom/config.py`. Controls data parallelism. Environment variables +(Section 8) override defaults when set. + +| Field | Type | Default | Description | +|---|---|---|---| +| `data_parallel_size` | `int` | `1` | Number of data-parallel groups; overridden by `ATOM_DP_SIZE` env var | +| `data_parallel_size_local` | `int` | `1` | Number of local data-parallel groups | +| `data_parallel_rank` | `int` | `0` | Rank within the data-parallel group; overridden by `ATOM_DP_RANK` | +| `data_parallel_rank_local` | `Optional[int]` | `None` | Local rank within the data-parallel group (SPMD mode); overridden by `ATOM_DP_RANK_LOCAL` | +| `data_parallel_master_port` | `int` | `29500` | Port used by the data-parallel master for process group initialization | +| `data_parallel_base_port` | `int` | `get_open_port()` | Base port for data-parallel communication (dynamically assigned) | +| `data_parallel_master_ip` | `str` | `"127.0.0.1"` | IP address of the data-parallel master | + +**Computed property:** + +- `world_size` -- set during init, equals TP x PP. +- `world_size_across_dp` -- `world_size * data_parallel_size`. + +--- + +## 5. Speculative Decoding Configuration (`SpeculativeConfig`) + +Defined in `atom/config.py`. Currently only the Multi-Token Prediction (MTP) +method with `num_speculative_tokens=1` is supported. + +| Field | Type | Default | Description | +|---|---|---|---| +| `method` | `Optional[str]` | `""` | Speculative decoding method; currently only `"mtp"` is accepted | +| `model` | `Optional[str]` | `None` | Draft model name or path (typically the same as the target model for MTP) | +| `num_speculative_tokens` | `Optional[int]` | `None` | Number of speculative tokens per iteration; **must be `1`** | +| `draft_model_hf_config` | `Optional[PretrainedConfig]` | `None` | HuggingFace config for the draft model; auto-loaded from `model` when `None` | + +**Post-init behaviour:** + +- Loads `draft_model_hf_config` from `model` if not provided. +- For DeepSeek V3 / MTP models: overrides `model_type` to `"deepseek_mtp"`, sets `n_predict=1` and `num_nextn_predict_layers=1`, and switches architectures to `["DeepSeekMTPModel"]`. +- `Config.__post_init__` raises `ValueError` if `num_speculative_tokens != 1`. + +--- + +## 6. Sampling Parameters (`SamplingParams`) + +Defined in `atom/sampling_params.py`. Passed per-request to control generation. + +| Field | Type | Default | Description | +|---|---|---|---| +| `temperature` | `float` | `1.0` | Sampling temperature; lower values make output more deterministic | +| `max_tokens` | `int` | `64` | Maximum number of tokens to generate | +| `ignore_eos` | `bool` | `False` | Continue generating past the EOS token | +| `stop_strings` | `Optional[list[str]]` | `None` | List of strings that trigger generation to stop | + +--- + +## 7. CLI Arguments (`EngineArgs`) + +Defined in `atom/model_engine/arg_utils.py`. The `EngineArgs` dataclass exposes +all flags via `add_cli_args()` and converts them into a `Config` via +`create_engine()`. + +| Flag | Short | Type | Default | Description | +|---|---|---|---|---| +| `--model` | | `str` | `"Qwen/Qwen3-0.6B"` | Model name or path | +| `--trust-remote-code` | | flag | `False` | Trust remote code when loading model | +| `--tensor-parallel-size` | `-tp` | `int` | `1` | Tensor parallel size | +| `--data-parallel-size` | `-dp` | `int` | `1` | Data parallel size | +| `--enforce-eager` | | flag | `False` | Enforce eager mode execution | +| `--enable_prefix_caching` | | flag | `False` | Enable prefix caching | +| `--port` | | `int` | `8006` | Engine internal port | +| `--kv_cache_dtype` | | `str` | `"bf16"` | KV cache dtype; choices: `bf16`, `fp8` | +| `--block-size` | | `int` | `16` | KV cache block size (maps to `kv_cache_block_size`) | +| `--max-model-len` | | `int` | `None` | Maximum model context length; defaults to `hf_config.max_position_embeddings` | +| `--cudagraph-capture-sizes` | | `str` | `"[1,2,4,8,16,32,48,64,128,256]"` | CUDA graph capture sizes as a Python list string | +| `--level` | | `int` | `3` | Compilation level (0 -- 3) | +| `--load_dummy` | | flag | `False` | Skip loading model weights | +| `--enable-expert-parallel` | | flag | `False` | Enable Expert Parallelism (EP MoE) | +| `--torch-profiler-dir` | | `str` | `None` | Directory for torch profiler traces | +| `--enable-dp-attention` | | flag | `False` | Enable DP attention | +| `--method` | | `str` | `None` | Speculative method; choices: `mtp` | +| `--num-speculative-tokens` | | `int` | `1` | Number of speculative tokens per iteration | +| `--max-num-batched-tokens` | | `int` | `16384` | Maximum number of tokens to batch in the async engine | +| `--max-num-seqs` | | `int` | `512` | Maximum number of sequences to batch together | +| `--gpu-memory-utilization` | | `float` | `0.9` | Fraction of GPU memory to use (0.0 -- 1.0) | +| `--scheduler-delay-factor` | | `float` | `0.0` | Delay factor multiplied by previous prompt latency before scheduling next prompt | + +**Example:** + +```bash +python -m atom.entrypoint \ + --model deepseek-ai/DeepSeek-R1 \ + --tensor-parallel-size 8 \ + --level 3 \ + --cudagraph-capture-sizes "[1,2,4,8,16,32,64,128,256]" \ + --kv_cache_dtype fp8 \ + --gpu-memory-utilization 0.92 \ + --max-num-seqs 256 +``` + +--- + +## 8. Environment Variables + +### 8.1 Variables Registered in `atom/utils/envs.py` + +All variables use lazy evaluation. Boolean variables treat `"1"` as `True` and +anything else (including unset) as `False`, unless noted otherwise. + +| Variable | Type | Default | Description | +|---|---|---|---| +| `ATOM_DP_RANK` | `int` | `0` | Data-parallel rank of this process | +| `ATOM_DP_RANK_LOCAL` | `int` | `0` | Local data-parallel rank (for SPMD mode) | +| `ATOM_DP_SIZE` | `int` | `1` | Total number of data-parallel groups | +| `ATOM_DP_MASTER_IP` | `str` | `"127.0.0.1"` | IP address of the data-parallel master | +| `ATOM_DP_MASTER_PORT` | `int` | `29500` | Port of the data-parallel master | +| `ATOM_ENFORCE_EAGER` | `bool` | `False` | Force eager mode globally (also set programmatically by `set_current_atom_config`) | +| `ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION` | `bool` | `False` | Enable QK-norm + RoPE + cache + quant fusion; enable for Qwen3-MoE models | +| `ATOM_USE_TRITON_GEMM` | `bool` | `False` | Use Triton-based GEMM kernels instead of default backends | +| `ATOM_USE_TRITON_MXFP4_BMM` | `bool` | `False` | Use Triton-based MXFP4 batched matrix multiply | +| `ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION` | `bool` | `True` | Enable fused input RMSNorm + quantization for DeepSeek models | +| `ATOM_ENABLE_DS_QKNORM_QUANT_FUSION` | `bool` | `True` | Enable fused QK-norm + quantization for DeepSeek models | +| `ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION` | `bool` | `True` | Enable fused all-reduce + RMSNorm kernel | +| `ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT` | `bool` | `True` | Enable AITER Triton fused RMSNorm + quantization for LLaMA models | +| `ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT` | `bool` | `True` | Enable AITER Triton fused SiLU + multiply + quantization for LLaMA models | + +### 8.2 Additional Environment Variables (Used Outside `envs.py`) + +| Variable | Type | Default | Where Used | Description | +|---|---|---|---|---| +| `ATOM_TORCH_PROFILER_DIR` | `str` | `None` | `atom/config.py` (`Config.torch_profiler_dir`) | Directory for PyTorch profiler output; sets the default for `Config.torch_profiler_dir` | +| `ATOM_PROFILER_MORE` | `str` | `"0"` | `atom/model_engine/model_runner.py` | Set to `"1"` to enable detailed profiling (`record_shapes`, `with_stack`, `profile_memory`) | +| `HF_TOKEN` | `str` | `None` | `atom/config.py` (`get_hf_config`) | HuggingFace authentication token for gated model downloads | + +--- + +## 9. Decision Tree -- Choosing a Compilation Level + +``` +Start + | + v +Is this a debugging / development run? + |-- Yes --> Level 0 (NO_COMPILATION) or --enforce-eager + | + v +Do you need torch.compile but no graph splitting? + |-- Yes, one-shot compile --> Level 2 (DYNAMO_ONCE) + |-- Yes, keep Dynamo default --> Level 1 (DYNAMO_AS_IS) + | + v +Production inference on ROCm/HIP GPU? + |-- Yes --> Level 3 (PIECEWISE) [default in EngineArgs] + - Auto-sets CUDAGraphMode.PIECEWISE + - Auto-populates splitting_ops for attention ops + - Pair with --cudagraph-capture-sizes for your batch profile + | + v +Need maximum decode throughput? + |-- Yes --> Level 3 + set cudagraph_mode to FULL_AND_PIECEWISE + (full graphs for decode, piecewise for prefill) +``` + +**Rules of thumb:** + +- **Level 3** is the default for `EngineArgs` and is recommended for most + production workloads. +- **Level 0** / `--enforce-eager` is useful for debugging, profiling, or when + CUDA graphs are incompatible with your model. +- Match `--cudagraph-capture-sizes` to your expected batch sizes for optimal + memory usage and launch latency. +- When using `--enable-dp-attention` or Expert Parallelism (`--enable-expert-parallel`), + level 3 is still recommended. + +--- + +## Source Files + +| File | Description | +|---|---| +| `atom/config.py` | `Config`, `CompilationConfig`, `CompilationLevel`, `CUDAGraphMode`, `QuantizationConfig`, `ParallelConfig`, `SpeculativeConfig`, `KVCacheTensor`, `KVCacheConfig`, `get_quant_config`, `get_hf_config` | +| `atom/utils/envs.py` | All `ATOM_*` environment variable definitions with lazy evaluation | +| `atom/model_engine/arg_utils.py` | `EngineArgs` dataclass and CLI argument parser | +| `atom/sampling_params.py` | `SamplingParams` dataclass | +| `atom/model_engine/model_runner.py` | Uses `ATOM_PROFILER_MORE` and `ATOM_TORCH_PROFILER_DIR` for profiling | diff --git a/docs/distributed_guide.md b/docs/distributed_guide.md new file mode 100644 index 000000000..3084a5841 --- /dev/null +++ b/docs/distributed_guide.md @@ -0,0 +1,454 @@ +# ATOM Distributed Inference Guide + +ATOM (AiTer Optimized Model) supports three parallelism strategies for distributed LLM inference on AMD ROCm/HIP GPUs: Tensor Parallelism (TP), Data Parallelism (DP), and Expert Parallelism (EP). These can be combined to scale across multiple GPUs for large model serving. + +## Quick Reference + +| Parallelism | CLI Flag | Purpose | Communication | +|-------------|----------|---------|---------------| +| Tensor Parallel (TP) | `-tp N` / `--tensor-parallel-size N` | Shard weights across GPUs | NCCL AllReduce | +| Data Parallel (DP) | `-dp N` / `--data-parallel-size N` | Replicate model, split requests | Gloo AllReduce (CPU) | +| Expert Parallel (EP) | `--enable-expert-parallel` | Distribute MoE experts across GPUs | MORI All-to-All | +| DP Attention | `--enable-dp-attention` | Flatten DP into TP for MoE layers | NCCL AllGather/ReduceScatter | + +**Common configurations:** + +| Model Type | Configuration | Example | +|-----------|---------------|---------| +| Dense (Llama, Qwen3) | TP only | `-tp 8` | +| MoE (Qwen3-235B) | TP + EP | `-tp 8 --enable-expert-parallel` | +| MoE throughput scaling | TP + DP + EP | `-tp 4 -dp 2 --enable-expert-parallel` | +| Dense throughput scaling | TP + DP | `-tp 4 -dp 2` | + +--- + +## 1. Tensor Parallelism (TP) + +Tensor Parallelism shards model weights across GPUs so each GPU holds a slice of every layer. ATOM uses AITER's `init_dist_env()` to initialize NCCL process groups. + +### Weight Sharding + +ATOM provides parallel linear layer classes in `atom/model_ops/linear.py`: + +- **`ColumnParallelLinear`** -- splits the output dimension (dim 0) across TP ranks. Each GPU computes a shard of the output independently. +- **`RowParallelLinear`** -- splits the input dimension (dim 1) across TP ranks. After the local matmul, an AllReduce across the TP group aggregates partial results. +- **`QKVParallelLinear`** -- extends `ColumnParallelLinear` for attention Q/K/V projections. Partitions heads across TP ranks, replicating KV heads when `num_kv_heads < tp_size`. +- **`MergedColumnParallelLinear`** -- merges multiple column-parallel outputs (e.g., gate and up projections) into a single weight tensor, sharded along dim 0. +- **`ReplicatedLinear`** -- no sharding; weight is replicated on every rank. + +### Process Group Initialization + +In `ModelRunner.__init__()`, the distributed environment is set up via AITER: + +```python +from aiter import init_dist_env +from aiter.dist.parallel_state import get_tp_group, get_dp_group, get_pp_group + +init_dist_env( + config.tensor_parallel_size, + rankID=rank, + backend="nccl", + distributed_init_method=distributed_init_method, + data_parallel_size=config.parallel_config.data_parallel_size, + data_parallel_rank=config.parallel_config.data_parallel_rank, +) +``` + +After initialization, `get_tp_group()`, `get_dp_group()`, and `get_pp_group()` provide the respective process groups for collective operations. + +### AllReduce + +The AllReduce happens inside `LinearBase.forward()` when `tp_dim == 1` (row-parallel): + +```python +if self.tp_dim == 1 and self.tp_size > 1 and self.reduce_results: + y = get_tp_group().all_reduce(y, ca_fp8_quant=False) +``` + +### Configuration + +- `Config.tensor_parallel_size` (int, default `1`): Number of TP ranks. Must satisfy `1 <= tensor_parallel_size <= 8`. +- CLI: `--tensor-parallel-size N` or `-tp N` + +--- + +## 2. Data Parallelism (DP) + +Data Parallelism runs multiple independent engine replicas, each handling a subset of incoming requests. DP is coordinated at the scheduling level rather than the model level -- each DP rank has its own `EngineCore`, scheduler, and model runner. + +### Architecture + +When `data_parallel_size > 1`, `EngineCore.run_engine()` instantiates a `DPEngineCoreProc` instead of a plain `EngineCore`: + +```python +# atom/model_engine/engine_core.py +@staticmethod +def run_engine(config, input_address, output_address): + if config.parallel_config.data_parallel_size > 1: + engine = DPEngineCoreProc(config, input_address, output_address) + else: + engine = EngineCore(config, input_address, output_address) + engine.busy_loop() +``` + +### DP Process Group Initialization + +`DPEngineCoreProc._init_data_parallel()` creates a Gloo-based process group for CPU-side coordination: + +```python +def _init_data_parallel(self, config): + dp_rank = config.parallel_config.data_parallel_rank + dp_size = config.parallel_config.data_parallel_size + local_dp_rank = config.parallel_config.data_parallel_rank_local + + assert dp_size > 1 + assert local_dp_rank is not None + + self.dp_rank = dp_rank + self.dp_group = config.parallel_config.stateless_init_dp_group() +``` + +The `stateless_init_dp_group()` method (in `ParallelConfig`) calls `stateless_init_torch_distributed_process_group()` with the `gloo` backend, creating an isolated process group that does not interfere with the NCCL TP group. + +### Synchronized Busy Loop + +The DP busy loop overrides the base `EngineCore.busy_loop()` to synchronize state across DP ranks before each step. The `_sync_dp_state()` method packs four signals into an int64 tensor and performs a single `AllReduce(MAX)`: + +```python +# State synced: [is_prefill, num_tokens, has_unfinished, shutdown] +state_tensor = torch.tensor( + [ + 1 if local_is_prefill else 0, + local_num_tokens, + 1 if local_has_unfinished else 0, + 1 if local_shutdown else 0, + ], + dtype=torch.int64, device="cpu", +) +torch.distributed.all_reduce( + state_tensor, op=torch.distributed.ReduceOp.MAX, group=self.dp_group +) +``` + +This ensures: +- **All ranks agree on the batch type** (prefill vs. decode). Since MORI requires all DP ranks to execute the same phase, a rank that has no prefill work must run a dummy prefill when any other rank does prefill. +- **Graceful shutdown**: all ranks must agree before exiting. +- **Token count alignment**: the maximum token count across ranks is used for padding. + +### Dummy Batch Execution + +When a DP rank has no real work but other ranks do, it executes dummy batches to participate in collective operations: + +- **`_execute_dummy_batch()`** -- runs a 1-token decode dummy through the model, triggering AllReduce and MORI collectives so other ranks are not blocked. +- **`_execute_dummy_prefill(num_tokens)`** -- runs a dummy prefill with the same token count as the max across DP ranks, so that MORI dispatch/combine stays synchronized. + +### Device Assignment + +When DP is enabled on a single node, each DP rank uses a different set of GPUs. The device mapping in `ModelRunner.__init__()` is: + +```python +local_device_rank = dp_rank_local * config.tensor_parallel_size + rank +device = torch.device(f"cuda:{local_device_rank}") +``` + +For example, with DP=2 and TP=4: +- DP rank 0: GPUs 0, 1, 2, 3 +- DP rank 1: GPUs 4, 5, 6, 7 + +### DPMetadata + +The `DPMetadata` dataclass (in `atom/utils/forward_context.py`) tracks token distribution across DP ranks for padding and collective operations: + +```python +@dataclass +class DPMetadata: + max_tokens_across_dp_cpu: torch.Tensor # Max tokens on any DP rank + cu_tokens_across_dp_cpu: torch.Tensor # Cumulative token counts + max_tokens_across_dp: int # Pre-computed int for CUDA graph +``` + +`DPMetadata.num_tokens_across_dp()` gathers token counts via an AllReduce on the DP CPU group: + +```python +num_tokens_across_dp = [0] * dp_size +num_tokens_across_dp[dp_rank] = num_tokens +num_tokens_tensor = torch.tensor(num_tokens_across_dp, device="cpu", dtype=torch.int32) +dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) +``` + +### CoreManager (DP Orchestration) + +`CoreManager` (in `atom/model_engine/engine_core_mgr.py`) manages multiple DP engine processes: + +1. For each DP rank, it creates a `Config` copy with the appropriate `data_parallel_rank` and `data_parallel_rank_local`. +2. Launches each `EngineCore` in a separate `multiprocessing.Process`. +3. Uses ZMQ (ROUTER/DEALER) sockets for input distribution and ZMQ (PUSH/PULL) for output collection. +4. Distributes incoming requests across DP ranks via round-robin load balancing. +5. Waits for READY signals from all ranks before accepting requests. + +When `enable_dp_attention` is set, `CoreManager` flattens TP into DP: + +```python +if config.enable_dp_attention: + self.local_engine_count = config.tensor_parallel_size * config.parallel_config.data_parallel_size + config.parallel_config.data_parallel_size = self.local_engine_count + config.tensor_parallel_size = 1 +``` + +### Configuration + +- `ParallelConfig.data_parallel_size` (int, default `1`): Number of DP replicas. +- `ParallelConfig.data_parallel_rank` (int, default `0`): This rank's DP index. +- `ParallelConfig.data_parallel_rank_local` (int, default `None`): Local DP rank on this node. +- CLI: `--data-parallel-size N` or `-dp N` + +--- + +## 3. Expert Parallelism (EP) + +Expert Parallelism distributes MoE experts across GPUs so that each GPU owns a subset of experts. Tokens are routed to the correct GPU via all-to-all communication. + +### FusedMoEParallelConfig + +The `FusedMoEParallelConfig` dataclass (in `atom/model_ops/moe.py`) determines how MoE layers are parallelized: + +```python +@dataclass +class FusedMoEParallelConfig: + tp_size: int # Tensor parallel size (1 when EP is active) + dp_size: int # Data parallel size + ep_size: int # Expert parallel size + tp_rank: int + dp_rank: int + ep_rank: int + use_ep: bool # Whether EP is enabled + local_ep_size: int # Number of EP ranks on this node +``` + +Key properties: + +- **`use_all2all_kernels`**: returns `True` when `dp_size > 1 and use_ep and mori is available`. This activates the MORI all-to-all dispatch/combine kernels. +- When EP is enabled, `tp_size` is set to 1 and `ep_size = dp_size * tp_size` (the original TP size). Each device fully owns its assigned experts. + +The `FusedMoEParallelConfig.make()` static method constructs the config: + +```python +use_ep = dp_size_ * tp_size_ > 1 and parallel_config.enable_expert_parallel + +if enable_dp_attention: + # Flatten DP into TP: effective tp_size = dp_size * tp_size + tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + +if use_ep: + ep_size = tp_size + ep_rank = tp_rank + # Each device owns experts fully -- no intra-expert tensor parallelism + return FusedMoEParallelConfig(tp_size=1, tp_rank=0, ep_size=ep_size, ...) +``` + +### Expert Distribution + +In `FusedMoE.__init__()`, when EP is active, the global experts are partitioned: + +```python +if self.use_ep: + self.local_num_experts, self.expert_map = determine_expert_map( + ep_size=self.ep_size, + ep_rank=self.ep_rank, + global_num_experts=self.global_num_experts, + ) +else: + self.local_num_experts = self.global_num_experts + self.expert_map = None +``` + +Each GPU only loads weights for its assigned experts, reducing per-GPU memory usage proportionally. + +### MORI Communication + +When `use_all2all_kernels` is `True`, the `MoriPrepareAndFinalize` class (in `atom/model_ops/fused_moe/mori_prepare_finalize.py`) handles token routing: + +**Dispatch phase** (`prepare()`): +1. Receives input activations, top-k weights, and top-k expert IDs. +2. Calls `self.mori_op.dispatch()` to send each token to the GPU that owns its selected expert. +3. Returns dispatched activations, scales, expert IDs, weights, and per-expert token counts. + +```python +(dispatch_a1, dispatch_weights, dispatch_scale, dispatch_ids, dispatch_recv_token_num +) = self.mori_op.dispatch(a1, topk_weights, scale, topk_ids, block_num, warp_per_block) +``` + +**Combine phase** (`finalize()`): +1. After expert computation, calls `self.mori_op.combine()` to route results back to the originating GPU. +2. Copies the combined result into the output tensor. + +```python +result = self.mori_op.combine(fused_expert_output, None, topk_ids, block_num, warp_per_block)[0] +output.copy_(result[:num_token]) +``` + +The block configuration adapts to the batch type: prefill uses `block_num=128, warp_per_block=16`, while decode uses `block_num=64, warp_per_block=4`. + +### Configuration + +- `Config.enable_expert_parallel` (bool, default `False`): Activates EP for MoE layers. +- `Config.enable_dp_attention` (bool, default `False`): Flattens DP ranks into the TP/EP dimension for MoE, while using per-rank attention for non-MoE layers. +- CLI: `--enable-expert-parallel`, `--enable-dp-attention` + +--- + +## 4. Environment Variables + +| Variable | Type | Default | Description | +|----------|------|---------|-------------| +| `ATOM_DP_RANK` | int | `0` | Data parallel rank index | +| `ATOM_DP_RANK_LOCAL` | int | `0` | Local data parallel rank on this node | +| `ATOM_DP_SIZE` | int | `1` | Total number of data parallel replicas | +| `ATOM_DP_MASTER_IP` | str | `127.0.0.1` | IP address for DP Gloo rendezvous | +| `ATOM_DP_MASTER_PORT` | int | `29500` | Port for DP Gloo rendezvous | +| `ATOM_ENFORCE_EAGER` | bool | `False` | Disable CUDA graphs (set automatically) | +| `ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION` | bool | `False` | Fuse QK-norm + RoPE + cache quant (for Qwen3-MoE) | + +Environment variables in `atom/utils/envs.py` are evaluated lazily via `__getattr__`. If `ATOM_DP_SIZE`, `ATOM_DP_RANK`, or `ATOM_DP_RANK_LOCAL` are set in the environment, they override programmatic `ParallelConfig` defaults in `ParallelConfig.__post_init__()`. + +**AITER environment variable (not in envs.py):** + +| Variable | Type | Default | Description | +|----------|------|---------|-------------| +| `AITER_QUICK_REDUCE_QUANTIZATION` | str | -- | Set to `INT4` to enable quantized AllReduce for prefill (read by AITER's AllReduce kernel) | + +--- + +## 5. Multi-GPU Deployment Examples + +### DeepSeek-R1 on 8 GPUs (TP8) + +From the project README -- a dense MLA model deployed with pure tensor parallelism: + +```bash +python -m atom.entrypoints.openai_server \ + --kv_cache_dtype fp8 \ + -tp 8 \ + --model deepseek-ai/DeepSeek-R1 +``` + +### Qwen3-235B-A22B on 8 GPUs (TP8 + EP) + +From `recipes/Qwen3-235b.md` -- a MoE model with 128 experts, deployed with tensor parallelism and expert parallelism: + +```bash +export AITER_QUICK_REDUCE_QUANTIZATION=INT4 +export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 + +python -m atom.entrypoints.openai_server \ + --model Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 \ + -tp 8 \ + --kv_cache_dtype fp8 \ + --enable-expert-parallel \ + --max-model-len 16384 \ + --max-num-batched-tokens 20000 +``` + +Tips from the recipe: +- Use FP8 KV cache (`--kv_cache_dtype fp8`) for memory efficiency. +- Quick AllReduce with INT4 quantization reduces prefill TTFT. +- QK-norm + RoPE + cache quant fusion improves Qwen3-MoE kernel performance. + +### Kimi-K2-Thinking on 4 GPUs (TP4) + +From `recipes/Kimi-K2-Thinking.md` -- an MXFP4 MoE model: + +```bash +export HIP_VISIBLE_DEVICES=0,1,2,3 + +python -m atom.entrypoints.openai_server \ + --model amd/Kimi-K2-Thinking-MXFP4 \ + --trust-remote-code \ + -tp 4 \ + --kv_cache_dtype fp8 +``` + +--- + +## 6. Combined Parallelism Strategies + +### TP Only (Dense Models) + +For dense models like Llama and Qwen3 (non-MoE), use pure tensor parallelism: + +```bash +python -m atom.entrypoints.openai_server --model meta-llama/Meta-Llama-3-8B -tp 8 +``` + +All weights are sharded across GPUs. AllReduce collectives synchronize after each `RowParallelLinear`. + +### TP + EP (MoE Models) + +For MoE models, enable expert parallelism so each GPU holds a subset of experts: + +```bash +python -m atom.entrypoints.openai_server --model Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 -tp 8 --enable-expert-parallel +``` + +Dense layers (attention, norms) remain tensor-parallel. MoE layers distribute experts across the `ep_size = tp_size` GPUs. MORI all-to-all routes tokens to the correct expert owner. + +### TP + DP (Dense Throughput) + +For throughput scaling with dense models, run multiple DP replicas: + +```bash +# On a node with 8 GPUs: 2 replicas, each using 4 GPUs +python -m atom.entrypoints.openai_server --model meta-llama/Meta-Llama-3-8B -tp 4 -dp 2 +``` + +Each DP replica independently processes a subset of requests. The `CoreManager` distributes requests via round-robin. Device mapping: +- DP rank 0, TP ranks 0-3 --> GPUs 0-3 +- DP rank 1, TP ranks 0-3 --> GPUs 4-7 + +Formula: `local_device_rank = dp_rank_local * tp_size + tp_rank` + +### TP + DP + EP (MoE Throughput) + +For MoE models with DP + EP, the expert parallel dimension spans all `tp_size * dp_size` devices: + +```bash +python -m atom.entrypoints.openai_server \ + --model Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 \ + -tp 4 -dp 2 \ + --enable-expert-parallel +``` + +In this configuration: +- Dense layers: each DP replica has TP=4 for sharding. +- MoE layers: EP size = `dp_size * tp_size = 8`, spreading experts across all 8 GPUs. +- MORI all-to-all crosses DP boundaries to route tokens to the correct expert owner. + +### DP Attention Mode + +When `--enable-dp-attention` is set, `CoreManager` flattens the TP dimension into DP: + +```python +local_engine_count = tensor_parallel_size * data_parallel_size +data_parallel_size = local_engine_count +tensor_parallel_size = 1 +``` + +This means each GPU runs an independent attention computation (no TP AllReduce for attention), while MoE layers still use the full EP group across all GPUs. This can reduce communication overhead for attention-heavy workloads. + +--- + +## Source Files + +| File | Description | +|------|-------------| +| `atom/config.py` | `ParallelConfig`, `Config.tensor_parallel_size`, `enable_expert_parallel`, `enable_dp_attention` | +| `atom/utils/envs.py` | `ATOM_DP_*` environment variables (lazy evaluation) | +| `atom/model_engine/engine_core.py` | `EngineCore`, `DPEngineCoreProc` (DP busy loop, sync, dummy batches) | +| `atom/model_engine/engine_core_mgr.py` | `CoreManager` (multi-process DP orchestration, ZMQ IPC) | +| `atom/model_engine/model_runner.py` | `ModelRunner` (`init_dist_env`, device assignment, `DPMetadata` usage) | +| `atom/model_engine/arg_utils.py` | `EngineArgs` CLI argument definitions | +| `atom/utils/distributed/utils.py` | `stateless_init_torch_distributed_process_group()` (Gloo PG creation) | +| `atom/utils/forward_context.py` | `DPMetadata`, `ForwardContext` (per-step DP token metadata) | +| `atom/model_ops/linear.py` | `ColumnParallelLinear`, `RowParallelLinear`, `QKVParallelLinear`, `MergedColumnParallelLinear` | +| `atom/model_ops/moe.py` | `FusedMoE`, `FusedMoEParallelConfig` (EP configuration and expert distribution) | +| `atom/model_ops/fused_moe/mori_prepare_finalize.py` | `MoriPrepareAndFinalize` (MORI dispatch/combine for EP) | diff --git a/docs/model_ops_guide.md b/docs/model_ops_guide.md new file mode 100644 index 000000000..969082719 --- /dev/null +++ b/docs/model_ops_guide.md @@ -0,0 +1,542 @@ +# ATOM Model Operations Guide + +ATOM (AiTer Optimized Model) wraps AITER kernels with model-level abstractions for LLM inference on AMD ROCm/HIP GPUs. This guide documents every operator class in `atom/model_ops/`, their AITER kernel mappings, quantization paths, and fused kernel chains. + +--- + +## Quick Reference + +| ATOM Class | File | AITER Kernel / Import | Purpose | +|---|---|---|---| +| `LinearBase` | `linear.py` | `tgemm.mm`, `gemm_a8w8`, `gemm_a8w8_bpreshuffle`, `gemm_a8w8_blockscale_bpreshuffle`, `gemm_a4w4` | Quantized linear dispatch | +| `ColumnParallelLinear` | `linear.py` | (inherits `LinearBase`) | Column-sharded TP linear | +| `RowParallelLinear` | `linear.py` | (inherits `LinearBase`) | Row-sharded TP linear | +| `QKVParallelLinear` | `linear.py` | (inherits `ColumnParallelLinear`) | Fused Q/K/V projection | +| `MergedColumnParallelLinear` | `linear.py` | (inherits `LinearBase`) | Merged gate+up projection | +| `Attention` | `base_attention.py` | `unified_attention_with_output_base` (custom op) | Unified attention entry | +| MHA `Attention` | `attention_mha.py` | `flash_attn_varlen_func`, `pa_fwd_asm`, `pa_persistent_fwd`, `pa_decode_gluon` | Multi-head attention | +| `MLAAttention` | `attention_mla.py` | `mla_decode_fwd`, `mla_prefill_fwd`, `concat_and_cache_mla`, `fused_qk_rope_concat_and_cache_mla` | Multi-head latent attention | +| `FusedMoE` | `moe.py` | `aiter.fused_moe.fused_moe`, `asm_moe` | Mixture of experts | +| `RMSNorm` | `layernorm.py` | `rmsnorm2d_fwd`, `rmsnorm2d_fwd_with_add`, `fused_add_rmsnorm_pad` | RMS normalization | +| `LayerNorm` | `layernorm.py` | `layernorm2d_fwd`, `layernorm2d_fwd_with_add` | Layer normalization | +| `SiluAndMul` | `activation.py` | `aiter.silu_and_mul` | SiLU gated activation | +| `VocabParallelEmbedding` | `embed_head.py` | `F.embedding` + TP all-reduce | Vocab embedding | +| `ParallelLMHead` | `embed_head.py` | `tgemm.mm` + `tensor_model_parallel_all_gather` | LM output head | +| `RotaryEmbedding` | `rotary_embedding.py` | `aiter.rope_cached_positions_2c_fwd_inplace` | Rotary position embedding | +| `Sampler` | `sampler.py` | `aiter.mixed_sample_outer_exponential`, `aiter.ops.triton.topk.topk`, `aiter.ops.triton.softmax.softmax` | Token sampling | +| `RejectionSampler` | `rejection_sampler.py` | Triton `rejection_greedy_sample_kernel` | Speculative decoding | + +--- + +## 1. AITER Integration Overview + +ATOM is a thin model-level inference engine. Every compute-heavy operation delegates to an AITER kernel. The general pattern is: + +1. An ATOM `nn.Module` owns model weights and configuration. +2. Its `forward()` method selects the appropriate AITER function based on quantization type, parallelism settings, and phase (prefill vs. decode). +3. Results are optionally reduced across tensor-parallel (TP) or data-parallel (DP) groups. + +### AITER Kernel Mapping Table + +| ATOM Wrapper | AITER Function / Import Path | Backend Type | +|---|---|---| +| `LinearBase.forward` (No quant) | `aiter.tuned_gemm.tgemm.mm` | hipBLASLt | +| `LinearBase.forward` (per_Tensor FP8) | `aiter.tuned_gemm.tgemm.mm` with scales | hipBLASLt | +| `LinearBase.forward` (per_Token INT8) | `aiter.gemm_a8w8` | CK | +| `LinearBase.forward` (per_Token FP8) | `aiter.gemm_a8w8_bpreshuffle` | CK | +| `LinearBase.forward` (per_1x128 FP8) | `aiter.gemm_a8w8_blockscale_bpreshuffle` | CK | +| `LinearBase.forward` (per_1x32 MXFP4) | `aiter.gemm_a4w4` | CK | +| MHA prefill | `aiter.flash_attn_varlen_func` | ASM / CK | +| MHA decode (ASM) | `aiter.pa_fwd_asm` | ASM | +| MHA decode (persistent ASM) | `aiter.pa_persistent_fwd` | ASM | +| MHA decode (Triton) | `aiter.ops.triton.gluon.pa_decode_gluon` | Triton | +| MHA prefill (Triton unified) | `aiter.ops.triton.unified_attention.unified_attention` | Triton | +| MLA decode | `aiter.mla.mla_decode_fwd` | ASM | +| MLA prefill | `aiter.mla.mla_prefill_fwd` | ASM | +| MLA KV cache | `aiter.concat_and_cache_mla` | CK | +| RoPE | `aiter.rope_cached_positions_2c_fwd_inplace` | Triton | +| RMSNorm | `aiter.rmsnorm2d_fwd` | CK | +| SiLU+Mul | `aiter.silu_and_mul` | CK | +| TopK routing | `aiter.topk_softmax`, `aiter.grouped_topk`, `aiter.biased_grouped_topk` | CK | +| Sampling | `aiter.mixed_sample_outer_exponential` | CK | +| FusedMoE | `aiter.fused_moe.fused_moe` | CK | +| ASM MoE | `aiter.fused_moe_bf16_asm.asm_moe` | ASM | +| Quantization | `aiter.get_hip_quant(QuantType)` | CK / Triton | + +--- + +## 2. Linear Operations + +All linear layers inherit from `LinearBase` in `atom/model_ops/linear.py`. + +### 2.1 Class Hierarchy + +``` +LinearBase (nn.Module) + +-- ReplicatedLinear # No TP sharding + | +-- MergedReplicatedLinear + +-- ColumnParallelLinear # tp_dim=0, shard output + | +-- QKVParallelLinear # Fused Q/K/V with per-head sharding + +-- MergedColumnParallelLinear # tp_dim=0, merged gate+up + +-- RowParallelLinear # tp_dim=1, shard input, optional all-reduce +``` + +### 2.2 Quantization Dispatch + +`LinearBase.forward()` dispatches to different GEMM kernels based on `QuantType`: + +| `QuantType` | Weight dtype | GEMM Kernel | Scale Shape | +|---|---|---|---| +| `No` | BF16/FP16 | `tgemm.mm` (hipBLASLt) | None | +| `per_Tensor` | FP8 | `tgemm.mm` with `scale_a`, `scale_b` | `[num_partitions, 1]` | +| `per_Token` (INT8) | INT8 | `gemm_a8w8` | `[output_size, 1]` | +| `per_Token` (FP8) | FP8 | `gemm_a8w8_bpreshuffle` | `[output_size, 1]` | +| `per_1x128` | FP8 | `gemm_a8w8_blockscale_bpreshuffle` | `[ceil(N/128), ceil(K/128)]` | +| `per_1x32` | MXFP4 (`fp4x2`) | `gemm_a4w4` | `[N, ceil(K/32)]` (e8m0) | + +When `x_scale` is not provided, the input is dynamically quantized via `get_hip_quant(quant_type)`. + +### 2.3 Tensor Parallel Sharding + +- **ColumnParallelLinear** (`tp_dim=0`): Shards weight rows (output dimension) across GPUs. Each GPU owns `output_size / tp_size` rows. +- **RowParallelLinear** (`tp_dim=1`): Shards weight columns (input dimension). If `reduce_results=True`, output is all-reduced across TP group. +- **QKVParallelLinear**: Extends `ColumnParallelLinear` with per-head sharding. Q heads are evenly divided; KV heads are either divided or replicated when `num_kv_heads < tp_size`. +- **MergedColumnParallelLinear**: Handles gate and up projections merged into a single weight with `output_sizes` as a list (e.g., `[intermediate_size, intermediate_size]`). + +### 2.4 Weight Processing + +After loading, `process_weights_after_loading()` handles: +- **e4m3fn to e4m3fnuz normalization** (AMD FP8 format conversion). +- **Weight reshuffling** via `shuffle_weights()` for pre-shuffled GEMM kernels. +- **Scale reshuffling** via `fp4_utils.e8m0_shuffle()` for MXFP4 block scales. +- **Per-tensor requantization** via `requantize_with_max_scale()` when multiple output partitions have separate scales. + +--- + +## 3. Attention Operations + +### 3.1 Base: `Attention` (`base_attention.py`) + +The top-level `Attention` class in `base_attention.py` is a dispatcher. It: + +1. Selects the backend via `get_attn_backend()` from `atom/utils/selector.py`. +2. Instantiates the backend's implementation class (`impl_cls`). +3. Registers itself in `compilation_config.static_forward_context` under `layer_name`. +4. On `forward()`, calls `torch.ops.aiter.unified_attention_with_output_base`, which is a custom op decorated with `@mark_spliting_op` -- this prevents `torch.compile` from tracing into attention internals, enabling full-graph capture. + +Backend selection logic (in `selector.py`): + +| Condition | Backend Class | Implementation | +|---|---|---| +| `use_mla=True` | `AiterMLABackend` | `MLAAttention` from `attention_mla.py` | +| `use_mla=False` | `AiterBackend` | `Attention` from `attention_mha.py` | + +### 3.2 Multi-Head Attention (`attention_mha.py`) + +The MHA `Attention` class handles standard models (Llama, Qwen3, Mixtral, etc.). + +**Forward flow:** + +1. Reshape Q, K, V to `[num_tokens, num_heads, head_dim]`. +2. Apply RoPE + KV cache write via `rope_cache()`. +3. Dispatch to the appropriate backend via `dispatch_backend()`. + +**RoPE + KV cache paths:** + +| Condition | Kernel Chain | +|---|---| +| `q_norm` + `k_norm` + `rotary_emb` present | `fused_qk_norm_rope_cache_quant_shuffle` (single fused kernel for QK norm, RoPE, cache write, optional FP8 quant) | +| Triton path (`sliding_window != -1` or `head_dim != 128`) + `rotary_emb` | `fused_qk_rope_reshape_and_cache` (Triton fused RoPE + reshape + cache) | +| ASM path + `rotary_emb` | `rotary_emb(position, q, k)` then `reshape_and_cache` or `reshape_and_cache_with_pertoken_quant` | + +**Attention dispatch:** + +| Phase | Condition | Method | AITER Kernel | +|---|---|---|---| +| Prefill | Always | `prefill_attention` | `aiter.flash_attn_varlen_func` | +| Decode | `use_triton_attn=True` | `paged_attention_triton` | `torch.ops.aiter.pa_decode_gluon` | +| Decode | `block_size == 1024` | `paged_attention_persistent_asm` | `aiter.pa_persistent_fwd` | +| Decode | Default | `paged_attention_asm` | `aiter.pa_fwd_asm` | + +The `use_triton_attn` flag is set when `sliding_window != -1` or `head_dim != 128`. + +### 3.3 Multi-head Latent Attention (`attention_mla.py`) + +`MLAAttention` implements DeepSeek's MLA with a compressed KV representation. Key data structures: + +```python +@dataclass +class MLAModules: + q_lora_rank: Optional[int] + kv_lora_rank: int + qk_nope_head_dim: int + qk_rope_head_dim: int + qk_head_dim: int + v_head_dim: int + rotary_emb: torch.nn.Module + q_proj: Optional[torch.nn.Module] + kv_b_proj: torch.nn.Module + o_proj: torch.nn.Module + indexer: Optional[torch.nn.Module] +``` + +**Forward flow:** + +1. If prefill and not sparse: Standard MHA-style prefill with `flash_attn_varlen_func`, preceded by `kv_b_proj` GEMM to produce K_nope and V from compressed `kv_c_normed`. +2. Otherwise: Fused Q projection + K up-projection via batched FP8/FP4 BMM (`_q_proj_and_k_up_proj`), then: + - `fused_qk_rope_concat_and_cache_mla` writes to KV cache. + - Decode: `mla_decode_fwd` (ASM persistent MLA kernel). + - Prefill (sparse): `mla_prefill_fwd`. +3. V up-projection + O projection via batched BMM (`_v_up_proj_and_o_proj`). + +**Batched GEMM backends for MLA projections:** + +| Condition | Kernel | +|---|---| +| `ATOM_USE_TRITON_MXFP4_BMM=True` | `batched_gemm_a16wfp4` (Triton FP4 BMM) | +| Default | `batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant` (Triton FP8 BMM) | + +**Prefill GEMM optimizations** (for `kv_b_proj`): + +| Condition | Kernel | +|---|---| +| `ATOM_USE_TRITON_GEMM=True` + FP4 weights | `fused_gemm_afp4wfp4_preshuffle_split_cat` (GEMM + split K/V + cat rope in one kernel) | +| `ATOM_USE_TRITON_GEMM=True` + FP8 weights | `fused_gemm_a8w8_blockscale_preshuffle_split_cat` | +| Default | `kv_b_proj(kv_c_normed)` then manual split + cat | + +### 3.4 Backend Abstraction (`attentions/backends.py`) + +The `AttentionBackend` abstract class defines three required methods: + +- `get_name()` -- Returns backend identifier string. +- `get_builder_cls()` -- Returns the `AttentionMetadataBuilder` subclass. +- `get_impl_cls()` -- Returns the attention implementation class. + +`CommonAttentionBuilder` provides shared metadata preparation (slot mapping, block tables, cumulative sequence lengths) used by both `AiterBackend` and `AiterMLABackend`. + +### 3.5 KV Cache Operations + +| Operation | AITER Kernel | Used By | +|---|---|---| +| Standard KV cache write | `aiter.reshape_and_cache` | MHA (BF16 KV) | +| FP8 KV cache write | `aiter.reshape_and_cache_with_pertoken_quant` | MHA (FP8 KV) | +| MLA KV cache write | `aiter.concat_and_cache_mla` | MLA prefill | +| Fused QK RoPE + MLA cache | `aiter.fused_qk_rope_concat_and_cache_mla` | MLA decode | + +--- + +## 4. Mixture of Experts (MoE) + +### 4.1 `FusedMoE` Class (`moe.py`) + +`FusedMoE` is the top-level MoE module. It handles: +- Expert routing via `select_experts()`. +- Weight creation and quantization dispatch via `quant_method`. +- Tensor/Expert/Data parallelism via `FusedMoEParallelConfig`. +- Optional shared expert fusion and MORI communication. + +**Constructor parameters:** +```python +FusedMoE( + num_experts: int, # Global number of experts + top_k: int, # Experts per token + hidden_size: int, # Input hidden dimension + intermediate_size: int, # Expert intermediate dimension + reduce_results: bool, # Whether to all-reduce output + renormalize: bool, # Renormalize routing weights + use_grouped_topk: bool, # Use grouped top-k (DeepSeek) + activation: ActivationType, # Silu, Gelu, Swiglu, etc. + ... +) +``` + +### 4.2 Quantization Methods + +`FusedMoE` selects a `quant_method` at construction time: + +| Quant Config | Method Class | GEMM Kernel | +|---|---|---| +| `QuantType.No` | `UnquantizedFusedMoEMethod` | `aiter.fused_moe.fused_moe` | +| FP8 (`dtypes.fp8`) | `Fp8MoEMethod` | `aiter.fused_moe.fused_moe` with quant_type | +| FP8 compressed-tensors | `CompressedTensorsFp8MoEMethod` | `aiter.fused_moe.fused_moe` or `asm_moe` | +| MXFP4 (`dtypes.fp4x2`) | `Mxfp4MoEMethod` | `aiter.fused_moe.fused_moe` or Triton `triton_kernel_moe_forward` | + +The ASM MoE path (`asm_moe` from `aiter.fused_moe_bf16_asm`) is used by FP8 methods and supports `a16` mode where activations remain in BF16/FP16 while weights are FP8/INT8. + +### 4.3 TopK Routing (`topK.py`) + +| Routing Function | AITER Kernel | Used For | +|---|---|---| +| `rocm_aiter_topk_softmax` | `aiter.topk_softmax` | Standard top-k (Mixtral) | +| `rocm_aiter_grouped_topk` | `aiter.grouped_topk` | Grouped top-k (DeepSeek) | +| `rocm_aiter_biased_grouped_topk` | `aiter.biased_grouped_topk` | Biased grouped top-k (DeepSeek V3) | + +**Shared expert fusion:** When `is_rocm_aiter_fusion_shared_expert_enabled()` returns `True`, the top-k buffers are extended with shared expert IDs appended after routed expert IDs. This allows shared expert computation to be fused into the same MoE kernel call. The metadata is initialized via `init_aiter_topK_meta_data()`. + +### 4.4 `FusedMoEParallelConfig` + +```python +@dataclass +class FusedMoEParallelConfig: + tp_size: int # Tensor parallel size + dp_size: int # Data parallel size + ep_size: int # Expert parallel size + tp_rank: int + dp_rank: int + ep_rank: int + use_ep: bool # Whether expert parallelism is active + local_ep_size: int # Local EP size (GPUs per node * TP) +``` + +Key properties: +- `use_all2all_kernels`: `True` when `dp_size > 1`, EP is enabled, and MORI is available. +- `use_mori_kernels`: Always `True` (currently). + +### 4.5 MORI Integration (`fused_moe/mori_prepare_finalize.py`) + +MORI (MoE Router Infrastructure) provides all-to-all communication kernels for expert parallelism. `MoriPrepareAndFinalize` implements: + +- `prepare()`: Dispatches tokens to remote experts via `mori_op.dispatch()`. Optionally quantizes activations to FP8 before dispatch. +- `finalize()`: Combines expert outputs via `mori_op.combine()` and copies results back. + +The `FusedMoEModularKernel` orchestrates the prepare-compute-finalize pipeline. + +### 4.6 MoE Quantization Config (`fused_moe/config.py`) + +`FusedMoEQuantConfig` describes activation and weight quantization for MoE layers: + +```python +@dataclass +class FusedMoEQuantConfig: + _a1: FusedMoEQuantDesc # First activation (input to gate_up) + _a2: FusedMoEQuantDesc # Second activation (input to down_proj) + _w1: FusedMoEQuantDesc # gate_up_proj weights + _w2: FusedMoEQuantDesc # down_proj weights +``` + +Factory functions: +- `fp8_w8a8_moe_quant_config()` -- FP8 weights and activations. +- `mxfp4_w4a16_moe_quant_config()` -- MXFP4 weights, unquantized activations. +- `FUSED_MOE_UNQUANTIZED_CONFIG` -- No quantization. + +### 4.7 Triton MoE Fallback (`fused_moe_triton.py`) + +`triton_kernel_moe_forward()` provides a Triton-based MoE path using the `triton_kernels` library. It uses `routing()` for expert assignment and `matmul_ogs()` for the expert GEMM. This path is currently used for MXFP4 MoE on GFX94x hardware. + +--- + +## 5. Normalization + +### 5.1 `RMSNorm` (`layernorm.py`) + +`RMSNorm` supports multiple forward paths depending on configuration flags: + +| Condition | Kernel / Path | Returns | +|---|---|---| +| `x_pad_to_multiple > 0`, no residual | `fused_rmsnorm_pad_` (Triton `fused_add_rmsnorm_pad`) | Padded output | +| `x_pad_to_multiple > 0`, with residual | `fused_add_rmsnorm_pad_` | (output, residual) | +| `fused_allreduce=True` and `tp_size > 1` | `tensor_model_parallel_fused_allreduce_rmsnorm` | (output, residual) | +| `fused_quant=True` and `x_scale` provided | `fused_rms_fp8_per_tensor_static_quant` | (FP8 output, scale) | +| `fused_quant=True` and `per_1x32` | `fused_rms_mxfp4_quant` | (MXFP4 output, scale) | +| Default, no residual | `rmsnorm2d_fwd` | Output | +| Default, with residual | `rmsnorm2d_fwd_with_add` | (output, residual) | + +Constructor parameters: +```python +RMSNorm( + dim: int, + eps: float = 1e-6, + x_pad_to_multiple: int = 0, + fused_allreduce: bool = False, + fused_quant: bool = False, + quant_config: Optional[QuantizationConfig] = None, +) +``` + +### 5.2 `LayerNorm` (`layernorm.py`) + +`LayerNorm` wraps `layernorm2d_fwd` and `layernorm2d_fwd_with_add` (with bias support): + +```python +LayerNorm(dim: int, eps: float = 1e-6) +``` + +- Without residual: `layernorm2d_fwd(x, weight, bias, eps)` +- With residual: `layernorm2d_fwd_with_add(out, x, residual, residual_out, weight, bias, eps)` + +--- + +## 6. Activation Functions + +### 6.1 `SiluAndMul` (`activation.py`) + +`SiluAndMul` computes `SiLU(x_first_half) * x_second_half`. It splits the last dimension in half. + +| Condition | Kernel | Output | +|---|---|---| +| `fused_quant=True` + `x_scale` provided (FP8) | `fused_silu_mul_fp8_per_tensor_static_quant` | `(FP8 output, scale)` | +| `fused_quant=True` + `per_1x32` (MXFP4) | `fused_reduce_act_mul_and_mxfp4_quant` (via `mxfp4_act_mul_quant_fuse`) | `(MXFP4 output, scale)` | +| Default | `aiter.silu_and_mul(out, x)` | BF16 output | + +Constructor: +```python +SiluAndMul( + fused_quant: bool = False, + quant_config: Optional[QuantizationConfig] = None, +) +``` + +--- + +## 7. Embedding & Output Head + +### 7.1 `VocabParallelEmbedding` (`embed_head.py`) + +Partitions the vocabulary across TP ranks. Each rank holds `num_embeddings / tp_size` rows. + +**Forward:** +1. Mask input token IDs to this rank's partition range `[vocab_start_idx, vocab_end_idx)`. +2. `F.embedding()` on local partition. +3. Zero out out-of-range positions. +4. `all_reduce()` across TP group. + +### 7.2 `ParallelLMHead` (`embed_head.py`) + +Extends `VocabParallelEmbedding` for the output projection. Key differences: + +- **Forward** extracts only the last token per sequence during prefill (via `cu_seqlens_q[1:] - 1`). +- Uses `tgemm.mm(x, self.weight, self.bias)` for the logit computation (not `F.linear`). +- Calls `tensor_model_parallel_all_gather()` to gather logits across TP ranks. + +--- + +## 8. Rotary Position Embedding (RoPE) + +### 8.1 `RotaryEmbedding` (`rotary_embedding.py`) + +Precomputes cos/sin caches at initialization and applies RoPE in-place. + +**Constructor:** +```python +RotaryEmbedding( + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool = True, + dtype: Optional[torch.dtype] = None, +) +``` + +**Forward:** Calls `aiter.rope_cached_positions_2c_fwd_inplace(query_, key_, cos, sin, positions, rotate_style, ...)` which applies RoPE to Q and K tensors in-place using precomputed caches indexed by position IDs. + +### 8.2 `get_rope()` Factory + +```python +get_rope(head_size, rotary_dim, max_position, base, rope_scaling=None) +``` + +Returns a cached `RotaryEmbedding` instance. Currently `rope_scaling` must be `None`. + +### 8.3 Integration in Attention + +- **MHA** (`attention_mha.py`): RoPE is applied during the `rope_cache()` phase, either via the fused `fused_qk_norm_rope_cache_quant_shuffle` kernel, via `fused_qk_rope_reshape_and_cache`, or via standalone `rotary_emb(position, q, k)`. +- **MLA** (`attention_mla.py`): RoPE is applied to `q_pe` and `k_rope` tensors. During decode, this is fused into `fused_qk_rope_concat_and_cache_mla`. During prefill, it is applied via `self.rotary_emb(positions, prefill_q_pe, k_rope)`. + +--- + +## 9. Sampling + +### 9.1 `Sampler` (`sampler.py`) + +Unified sampling supporting both greedy (temperature=0) and random (temperature>0) sampling in a single kernel call. + +**Forward:** +```python +def forward(self, logits, temperatures) -> sampled_tokens: + mixed_sample_outer_exponential(sampled_tokens, logits, exponential, temperatures, eps) +``` + +`aiter.mixed_sample_outer_exponential` performs temperature-scaled exponential sampling: it divides logits by temperature, then uses the Gumbel-max trick with pre-generated exponential random variates. + +**Fallback methods** (currently unreachable due to early return): +- `greedy_sample()`: `aiter.ops.triton.topk.topk(logits, 1)` +- `random_sample()`: `aiter.ops.triton.softmax.softmax(logits)` followed by exponential sampling and `topk`. + +### 9.2 `RejectionSampler` (`rejection_sampler.py`) + +Implements rejection sampling for speculative decoding (MTP). Given draft token IDs and target model logits: + +1. Computes `target_argmax = target_logits.argmax(dim=-1)`. +2. Runs a Triton kernel `rejection_greedy_sample_kernel` that sequentially compares draft tokens against target argmax, accepting until first mismatch. +3. On full acceptance, appends the bonus token. +4. Returns `(output_token_ids, num_bonus_tokens)`. + +--- + +## 10. Fused Kernel Chains + +ATOM uses fused kernels to reduce memory traffic by combining multiple operations into a single kernel launch. + +| Fused Operation | Components | Controlled By | AITER Kernel | +|---|---|---|---| +| RMSNorm + FP8 quant | RMSNorm, per-tensor FP8 static quant | `RMSNorm(fused_quant=True)` + `x_scale` | `fused_rms_fp8_per_tensor_static_quant` | +| RMSNorm + MXFP4 quant | RMSNorm, per-1x32 MXFP4 quant | `RMSNorm(fused_quant=True)` + `QuantType.per_1x32` | `fused_rms_mxfp4_quant` | +| RMSNorm + add + pad | Residual add, RMSNorm, output padding | `RMSNorm(x_pad_to_multiple>0)` | `fused_add_rmsnorm_pad` | +| AllReduce + RMSNorm | TP all-reduce, RMSNorm | `RMSNorm(fused_allreduce=True)` | `tensor_model_parallel_fused_allreduce_rmsnorm` | +| SiLU + mul + FP8 quant | SiLU activation, multiply, FP8 quant | `SiluAndMul(fused_quant=True)` + `x_scale` | `fused_silu_mul_fp8_per_tensor_static_quant` | +| SiLU + mul + MXFP4 quant | SiLU activation, multiply, MXFP4 quant | `SiluAndMul(fused_quant=True)` + `QuantType.per_1x32` | `fused_reduce_act_mul_and_mxfp4_quant` | +| QK norm + RoPE + cache + quant | Q/K norm, RoPE, KV cache write, optional FP8 quant, weight shuffle | `q_norm` + `k_norm` + `rotary_emb` all present | `fused_qk_norm_rope_cache_quant_shuffle` | +| RoPE + reshape + cache | RoPE, K reshape, KV cache write | Triton attention path | `fused_qk_rope_reshape_and_cache` | +| QK RoPE + MLA cache | Q RoPE, KV concat, MLA cache write, FP8 quant | MLA decode path | `fused_qk_rope_concat_and_cache_mla` | +| GEMM + split + cat (FP4) | KV_b_proj GEMM, split K_nope/V, cat K_rope | `ATOM_USE_TRITON_GEMM=True` + FP4 weights | `fused_gemm_afp4wfp4_preshuffle_split_cat` | +| GEMM + split + cat (FP8) | KV_b_proj GEMM, split K_nope/V, cat K_rope | `ATOM_USE_TRITON_GEMM=True` + FP8 weights | `fused_gemm_a8w8_blockscale_preshuffle_split_cat` | +| FP8 BMM + RoPE + cache (MLA) | Batched FP8 BMM, RoPE, MLA KV cache write | MLA decode with FP8 | `fused_fp8_bmm_rope_cat_and_cache_mla` | +| FP4 BMM + RoPE + cache (MLA) | Batched FP4 BMM, RoPE, MLA KV cache write | MLA decode with MXFP4 | `fused_fp4_bmm_rope_cat_and_cache_mla` | + +--- + +## Source Files + +### `atom/model_ops/` + +| File | Description | +|---|---| +| `linear.py` | `LinearBase`, `ColumnParallelLinear`, `RowParallelLinear`, `QKVParallelLinear`, `MergedColumnParallelLinear`, `ReplicatedLinear`, `MergedReplicatedLinear` | +| `activation.py` | `SiluAndMul` with fused FP8/MXFP4 quantization | +| `layernorm.py` | `RMSNorm`, `LayerNorm` with fused allreduce/quant/pad variants | +| `base_attention.py` | Top-level `Attention` dispatcher with custom op registration | +| `attention_mha.py` | MHA implementation: prefill (flash), decode (ASM/Triton paged attention) | +| `attention_mla.py` | `MLAAttention`, `MLAModules` -- DeepSeek MLA with compressed KV | +| `moe.py` | `FusedMoE`, `FusedMoEParallelConfig`, `UnquantizedFusedMoEMethod`, `Fp8MoEMethod`, `Mxfp4MoEMethod`, `CompressedTensorsFp8MoEMethod` | +| `fused_moe_triton.py` | `triton_kernel_moe_forward` -- Triton MoE via `triton_kernels` library | +| `embed_head.py` | `VocabParallelEmbedding`, `ParallelLMHead` | +| `rotary_embedding.py` | `RotaryEmbedding`, `get_rope` | +| `topK.py` | `rocm_aiter_topk_softmax`, `rocm_aiter_grouped_topk`, `init_aiter_topK_meta_data` | +| `sampler.py` | `Sampler` -- unified greedy/random sampling | +| `rejection_sampler.py` | `RejectionSampler` -- speculative decoding rejection sampling | +| `base_config.py` | `QuantizeMethodBase` abstract class | +| `utils.py` | Helper utilities: `shuffle_weights`, `normalize_e4m3fn_to_e4m3fnuz`, `per_tensor_dequantize`, etc. | + +### `atom/model_ops/attentions/` + +| File | Description | +|---|---| +| `backends.py` | `AttentionBackend`, `AttentionMetadataBuilder`, `CommonAttentionBuilder`, `AttentionImpl` abstract classes | +| `aiter_attention.py` | `AiterBackend`, `AiterAttentionMetadataBuilder` -- MHA backend with persistent ASM paged attention support | +| `aiter_mla.py` | `AiterMLABackend`, `AiterMLAMetadataBuilder` -- MLA backend with sparse attention support | + +### `atom/model_ops/fused_moe/` + +| File | Description | +|---|---| +| `config.py` | `FusedMoEConfig`, `FusedMoEQuantConfig`, `FusedMoEQuantDesc`, `GroupShape`, factory functions (`fp8_w8a8_moe_quant_config`, `mxfp4_w4a16_moe_quant_config`) | +| `modular_kernel.py` | `FusedMoEModularKernel`, `FusedMoEPrepareAndFinalize`, `ExpertTokensMetadata` -- modular MoE kernel pipeline | +| `mori_prepare_finalize.py` | `MoriPrepareAndFinalize` -- MORI all-to-all dispatch/combine for expert parallelism | +| `utils.py` | MoE utility functions | + +### `atom/utils/` + +| File | Description | +|---|---| +| `selector.py` | `get_attn_backend()` -- selects `AiterBackend` or `AiterMLABackend` based on `use_mla` flag | diff --git a/docs/model_support_guide.md b/docs/model_support_guide.md new file mode 100644 index 000000000..285700568 --- /dev/null +++ b/docs/model_support_guide.md @@ -0,0 +1,320 @@ +# ATOM Model Support Guide + +ATOM (AiTer Optimized Model) is AMD's lightweight LLM inference engine built on AITER kernels for ROCm/HIP GPUs. This guide covers the supported model architectures, weight loading, and how to add new models. + +## Quick Reference + +The model registry lives in `atom/model_engine/model_runner.py` as `support_model_arch_dict`: + +```python +support_model_arch_dict = { + "Qwen3ForCausalLM": "atom.models.qwen3.Qwen3ForCausalLM", + "Qwen3MoeForCausalLM": "atom.models.qwen3_moe.Qwen3MoeForCausalLM", + "LlamaForCausalLM": "atom.models.llama.LlamaForCausalLM", + "MixtralForCausalLM": "atom.models.mixtral.MixtralForCausalLM", + "DeepseekV3ForCausalLM": "atom.models.deepseek_v2.DeepseekV2ForCausalLM", + "DeepseekV32ForCausalLM": "atom.models.deepseek_v2.DeepseekV2ForCausalLM", + "GptOssForCausalLM": "atom.models.gpt_oss.GptOssForCausalLM", + "Glm4MoeForCausalLM": "atom.models.glm4_moe.Glm4MoeForCausalLM", +} +``` + +ATOM resolves the HuggingFace `architectures` field from a model's `config.json` against this dictionary. If the architecture string matches a key, ATOM imports and instantiates the corresponding class. + +--- + +## 1. Supported Model Architectures + +| HF Architecture | ATOM Module | ATOM Class | MoE | MLA | Key Features | +|---|---|---|---|---|---| +| `Qwen3ForCausalLM` | `atom.models.qwen3` | `Qwen3ForCausalLM` | No | No | GQA, QK norm, RoPE | +| `Qwen3MoeForCausalLM` | `atom.models.qwen3_moe` | `Qwen3MoeForCausalLM` | Yes | No | GQA, QK norm, FusedMoE, sparse+dense layer mixing, QK norm+RoPE+cache+quant fusion | +| `LlamaForCausalLM` | `atom.models.llama` | `LlamaForCausalLM` | No | No | GQA, RoPE, fused RMSNorm+quant, fused SiLU+mul+quant | +| `MixtralForCausalLM` | `atom.models.mixtral` | `MixtralForCausalLM` | Yes | No | GQA, RoPE, FusedMoE with TP sharding | +| `DeepseekV3ForCausalLM` | `atom.models.deepseek_v2` | `DeepseekV2ForCausalLM` | Yes | Yes | MLA attention, LoRA-compressed QKV, FusedMoE with shared experts, FP4/FP8 fused kernels | +| `DeepseekV32ForCausalLM` | `atom.models.deepseek_v2` | `DeepseekV2ForCausalLM` | Yes | Yes | Same as above with V3.2 index-based top-k routing | +| `GptOssForCausalLM` | `atom.models.gpt_oss` | `GptOssForCausalLM` | Yes | No | GQA, RoPE, sliding window attention (every other layer), attention sinks, bias in QKV and MoE | +| `Glm4MoeForCausalLM` | `atom.models.glm4_moe` | `Glm4MoeForCausalLM` | Yes | No | GQA, partial RoPE (0.5 factor), QK norm, shared+routed experts, sigmoid scoring, grouped top-k | + +**Note:** `DeepSeekMTP` (`atom.models.deepseek_mtp.DeepSeekMTP`) is not in the registry -- it is used exclusively as a speculative draft model for DeepSeek multi-token prediction and is loaded separately. + +--- + +## 2. Model Architecture Details + +### Qwen3 (`Qwen3ForCausalLM`) + +- **Architecture:** Dense transformer with Grouped-Query Attention (GQA). +- **Layer structure:** `Qwen3DecoderLayer` containing `Qwen3Attention` + `Qwen3MLP`. +- **Attention:** `QKVParallelLinear` for fused QKV projection, per-head QK RMSNorm (`q_norm`, `k_norm`), RoPE, `RowParallelLinear` for output projection. +- **MLP:** `MergedColumnParallelLinear` for gate+up projection, SiLU activation, `RowParallelLinear` for down projection. +- **Normalization:** RMSNorm on input and post-attention. + +### Qwen3-MoE (`Qwen3MoeForCausalLM`) + +- **Architecture:** Mixture-of-Experts transformer with GQA. +- **Layer structure:** `Qwen3MoeDecoderLayer` containing `Qwen3MoeAttention` + either `Qwen3MoeSparseMoeBlock` (MoE layers) or `Qwen3MoeMLP` (dense layers, controlled by `mlp_only_layers` and `decoder_sparse_step`). +- **Attention:** Same QKV structure as Qwen3 with QK norm. Supports QK norm + RoPE + cache + quant fusion when `ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION` is set -- this precomputes a joint `cos_sin_cache` and passes `q_norm`/`k_norm` to the `Attention` module. +- **MoE:** `FusedMoE` with `ReplicatedLinear` gate router. Supports allreduce+RMSNorm fusion (`ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION`). +- **Normalization:** RMSNorm with optional fused allreduce. + +### Llama (`LlamaForCausalLM`) + +- **Architecture:** Dense transformer with GQA. Covers Llama 2/3 and compatible architectures (InternLM, Mistral-Nemo via optional `head_dim`). +- **Layer structure:** `LlamaDecoderLayer` containing `LlamaAttention` + `LlamaMLP`. +- **Attention:** `QKVParallelLinear`, RoPE (NeoX or original style based on GGUF), per-layer sliding window support via `layer_types` config. +- **MLP:** `MergedColumnParallelLinear` for gate+up, SiLU+mul activation, `RowParallelLinear` for down. +- **Fused optimizations:** Controlled by environment variables: + - `ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT` -- fuses RMSNorm with FP8/MXFP4 quantization. + - `ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT` -- fuses SiLU+mul activation with quantization. +- **Pipeline parallelism:** Full PP support with `PPMissingLayer` placeholders and `IntermediateTensors` for cross-stage communication. Supports auxiliary hidden state extraction for speculative decoding. + +### Mixtral (`MixtralForCausalLM`) + +- **Architecture:** Sparse Mixture-of-Experts with GQA. +- **Layer structure:** `MixtralDecoderLayer` containing `MixtralAttention` + `MixtralMoE`. +- **Attention:** Standard GQA with `QKVParallelLinear`, RoPE (NeoX style), `RowParallelLinear`. +- **MoE:** `MixtralMoE` wraps `ReplicatedLinear` gate + `FusedMoE`. Experts are sharded across TP ranks with full reduce. Gate checkpoint names use `w1`/`w2`/`w3` convention (mapped to `gate_proj`/`down_proj`/`up_proj`). +- **Normalization:** RMSNorm. + +### DeepSeek V2/V3 (`DeepseekV2ForCausalLM`) + +- **Architecture:** MoE transformer with Multi-head Latent Attention (MLA). +- **Layer structure:** `DeepseekV2DecoderLayer` containing `DeepseekV2MLAAttention` + either `DeepseekV2MoE` (MoE layers) or `DeepseekV2MLP` (dense layers). +- **MLA Attention:** Uses LoRA-compressed QKV (`q_lora_rank`, `kv_lora_rank`), separate `qk_nope_head_dim` and `qk_rope_head_dim` for non-positional and rotary-embedded components. Backed by `MLAModules` from `atom.model_ops.attention_mla`. +- **MoE:** `DeepseekV2MoE` with routed + shared experts. Supports shared expert fusion (`is_rocm_aiter_fusion_shared_expert_enabled`), routed scaling factor fusion (`is_rocm_aiter_fuse_routed_scaling_factor`), and grouped top-k routing. +- **Fused optimizations:** + - `ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION` -- fuses input RMSNorm with FP8/FP4 quantization. + - `ATOM_ENABLE_DS_QKNORM_QUANT_FUSION` -- fuses QK norm with quantization. + - `ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION` -- fuses allreduce with RMSNorm. + - Dedicated Triton kernels for FP8 MQA logits (`fp8_mqa_logits`), paged MQA logits (`deepgemm_fp8_paged_mqa_logits`), and fused RMSNorm+quantization (`_fuse_rmsnorm_quant`). +- **V3.2 extension:** `DeepseekV32ForCausalLM` is an alias. The `DeepseekV2Model` detects V3.2 via `config.index_topk` and allocates an `topk_indices_buffer` for index-based routing. +- **Note:** `DeepseekV3ForCausalLM` is a subclass of `DeepseekV2ForCausalLM` (pass-through, no override). + +### DeepSeek MTP (`DeepSeekMTP`) + +- **Architecture:** Multi-Token Prediction draft model for speculative decoding. +- **Layer structure:** `DeepSeekMultiTokenPredictor` containing one or more `DeepSeekMultiTokenPredictorLayer`, each with `enorm` (embedding norm), `hnorm` (hidden state norm), `eh_proj` (linear projection joining embedded+hidden), `mtp_block` (a `DeepseekV2DecoderLayer`), and a `SharedHead` (norm + LM head). +- **Usage:** Not registered in `support_model_arch_dict`. Loaded separately with `spec_decode=True` in `load_model()`, which invokes `rewrite_spec_layer_name()` to remap MTP weight names (e.g., adding `.mtp_block.` prefix for transformer layer weights, remapping `embed_tokens` to top-level). +- **MTP layers start** at `config.num_hidden_layers` (i.e., the layer indices following the main model layers). + +### GPT-OSS (`GptOssForCausalLM`) + +- **Architecture:** MoE transformer with GQA and alternating sliding window attention. +- **Layer structure:** `TransformerBlock` containing `OAIAttention` + `MLPBlock`. +- **Attention:** `OAIAttention` with bias on QKV and output projections, attention sinks (learnable per-head parameters), and sliding window applied on even-indexed layers only. +- **MoE:** `MLPBlock` wraps `ReplicatedLinear` router (with bias) + `FusedMoE` with SwiGLU activation and bias support. Custom `weights_mapping` translates checkpoint names (`gate_up_proj_blocks` to `w13_weight`, etc.). +- **Normalization:** RMSNorm with eps=1e-5, post-attention norm uses `x_pad_to_multiple=256`. +- **Pipeline parallelism:** Supports auxiliary hidden state layers for EAGLE3 speculative decoding (`get_eagle3_aux_hidden_state_layers`). + +### GLM4-MoE (`Glm4MoeForCausalLM`) + +- **Architecture:** MoE transformer with GQA, shared + routed experts, partial RoPE. +- **Layer structure:** `Glm4MoeDecoderLayer` containing `Glm4MoeAttention` + either `Glm4MoE` (MoE layers, from `first_k_dense_replace` onward) or `Glm4MoeMLP` (dense layers). +- **Attention:** `Glm4MoeAttention` with optional QK norm (`use_qk_norm`), partial rotary factor of 0.5. +- **MoE:** `Glm4MoE` with sigmoid scoring, `e_score_correction_bias`, grouped top-k routing (`n_group`, `topk_group`), routed scaling factor. Shared experts handled separately or fused into `FusedMoE` via `is_rocm_aiter_fusion_shared_expert_enabled()`. Expert parallelism (EP) support built in. +- **Inherits:** `Glm4MixtureOfExperts` mixin for MoE metadata management and expert load balancing (EPLB) support. + +--- + +## 3. Weight Loading + +Weight loading is handled by `load_model()` in `atom/model_loader/loader.py`. + +### Function Signature + +```python +def load_model( + model: nn.Module, + model_name_or_path: str, + hf_config: AutoConfig, + load_dummy: bool = False, + spec_decode: bool = False, +): +``` + +### Loading Flow + +1. **SafeTensors iteration:** `safetensors_weights_iterator()` discovers and iterates over all `*.safetensors` files in the model directory (or downloads them from HuggingFace Hub via `download_weights_from_hf()`). Duplicate files are filtered using the `model.safetensors.index.json` weight map. Memory-mapped loading is used by default; set `ATOM_DISABLE_MMAP=true` to disable. + +2. **Weight name rewriting:** Each weight name goes through several transformations: + - `weight_scale_inv` is renamed to `weight_scale`. + - Model-specific `weights_mapping` (e.g., GPT-OSS maps `gate_up_proj_blocks` to `w13_weight`). + - For speculative decoding (`spec_decode=True`), MTP layer weights are rewritten via `rewrite_spec_layer_name()`. + - Shared expert fusion: when enabled, `mlp.shared_experts` is remapped to `mlp.experts.` so the shared expert is loaded as the last expert in the `FusedMoE` module. + +3. **Packed module resolution:** The `packed_modules_mapping` dict on each model class defines how HuggingFace checkpoint weight names map to ATOM's fused parameter names. For example, Llama maps: + ```python + "q_proj": ("qkv_proj", "q"), + "k_proj": ("qkv_proj", "k"), + "v_proj": ("qkv_proj", "v"), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + ``` + Each packed parameter has a `weight_loader` attribute that knows how to shard and place the weight into the correct slice. + +4. **Expert parameter loading:** If the model has a `get_expert_mapping()` method, expert weights are loaded using `FusedMoE.make_expert_params_mapping()`, which generates (param_name, weight_name, expert_id, shard_id) tuples. This handles per-expert sharding across TP ranks. + +5. **TP sharding:** Parallel linear layers (`ColumnParallelLinear`, `RowParallelLinear`, `QKVParallelLinear`) have custom `weight_loader` methods that automatically select the correct shard for the current TP rank during loading. The default fallback `default_weight_loader` handles simple cases where weights need to be sliced by TP rank. + +6. **Concurrent loading:** All weight loading calls are submitted to a `ThreadPoolExecutor` for parallel execution. + +7. **Post-processing:** After all weights are loaded, `process_weights_after_loading()` is called on each module (e.g., for weight pre-shuffling, scale computation), and `quant_method.process_weights_after_loading()` is invoked for quantized modules. For `FusedMoEMethodBase`, `init_prepare_finalize()` is also called. + +### Layers Beyond `num_hidden_layers` + +Weights for layers with index >= `config.num_hidden_layers` are skipped during normal loading. These layers (MTP layers) are only loaded when `spec_decode=True`. + +--- + +## 4. Adding a New Model + +Follow these steps to add support for a new model architecture: + +### Step 1: Create the Model File + +Create a new file in `atom/models/`, e.g., `atom/models/my_model.py`. Follow the existing patterns: + +```python +from atom.config import Config, QuantizationConfig +from atom.model_ops.base_attention import Attention +from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding +from atom.model_ops.layernorm import RMSNorm +from atom.model_ops.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from atom.models.utils import ( + IntermediateTensors, + PPMissingLayer, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) +from atom.utils.decorators import support_torch_compile +``` + +### Step 2: Implement Layer Classes + +Each model typically defines three core module classes: + +1. **Attention module** (e.g., `MyModelAttention`): + - Initialize `QKVParallelLinear` for query/key/value. + - Initialize `RowParallelLinear` for output projection. + - Set up rotary embeddings via `aiter.rotary_embedding.get_rope()`. + - Create `Attention` from `atom.model_ops.base_attention`. + +2. **MLP module** (e.g., `MyModelMLP`): + - Use `MergedColumnParallelLinear` for gate+up projections. + - Use `RowParallelLinear` for down projection. + - For MoE models, use `FusedMoE` from `atom.model_ops.moe`. + +3. **Decoder layer** (e.g., `MyModelDecoderLayer`): + - Combine attention + MLP with RMSNorm layers. + - Implement the forward pass with residual connections. + +### Step 3: Implement the Model and CausalLM Classes + +1. **Backbone model** (e.g., `MyModel`): + - Decorate with `@support_torch_compile`. + - Initialize `VocabParallelEmbedding`, decoder layers via `make_layers()`, and final `RMSNorm`. + - Support pipeline parallelism with `PPMissingLayer` and `IntermediateTensors`. + +2. **CausalLM wrapper** (e.g., `MyModelForCausalLM`): + - Define `packed_modules_mapping` to map checkpoint weight names to ATOM's fused parameter names. + - Initialize the backbone model and `ParallelLMHead`. + - Implement `forward()` (returns hidden states) and `compute_logits()` (returns logits via `lm_head`). + - If the model uses MoE, implement `get_expert_mapping()` returning `FusedMoE.make_expert_params_mapping(...)`. + +### Step 4: Register the Model + +Add an entry to `support_model_arch_dict` in `atom/model_engine/model_runner.py`: + +```python +support_model_arch_dict = { + ... + "MyModelForCausalLM": "atom.models.my_model.MyModelForCausalLM", +} +``` + +The key must exactly match the `architectures` field in the HuggingFace model's `config.json`. + +### Step 5: Handle Weight Loading + +Ensure your `packed_modules_mapping` correctly maps all checkpoint weight names that differ from ATOM's internal names. Common patterns: + +| Checkpoint Name | ATOM Parameter | Shard ID | +|---|---|---| +| `q_proj` | `qkv_proj` | `"q"` | +| `k_proj` | `qkv_proj` | `"k"` | +| `v_proj` | `qkv_proj` | `"v"` | +| `gate_proj` | `gate_up_proj` | `0` | +| `up_proj` | `gate_up_proj` | `1` | + +For MoE models, add `get_expert_mapping()` to delegate to `FusedMoE.make_expert_params_mapping()` with the correct gate/down/up projection names and expert count. + +If the checkpoint uses non-standard weight names (like GPT-OSS), define a `weights_mapping` class attribute to rename them at load time. + +--- + +## 5. Model-Specific Optimizations + +### Llama: Fused RMSNorm+Quant and SiLU+Mul+Quant + +Llama supports two AITER Triton fused kernel optimizations: + +- **`ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT`**: Fuses the RMSNorm normalization with FP8 or MXFP4 quantization in a single kernel call. Applied to both `input_layernorm` and `post_attention_layernorm`. Eliminates an extra read/write pass over the hidden states. + +- **`ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT`**: Fuses the SiLU activation, element-wise multiply, and quantization in the MLP. The `SiluAndMul` module receives the `fused_quant=True` flag and the quant config, producing quantized output directly for the down projection. + +Both are controlled by environment variables and read from `atom.utils.envs`. + +### DeepSeek V2/V3: MLA + Fused Input Norm + QK Norm Fusion + +DeepSeek models use Multi-head Latent Attention (MLA) with LoRA-compressed projections (`q_lora_rank`, `kv_lora_rank`). Several fusion optimizations are available: + +- **`ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION`**: Fuses the input RMSNorm with quantization. Implemented via `_fuse_rmsnorm_quant()` which dispatches to either `_fuse_rmsnorm_fp4_quant()` or `_fused_rms_fp8_group_quant()` based on the quant dtype. When enabled, the allreduce+RMSNorm fusion is disabled for `input_layernorm` but kept for `post_attention_layernorm`. + +- **`ATOM_ENABLE_DS_QKNORM_QUANT_FUSION`**: Fuses the Q/K LoRA layernorm with quantization via `_fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4()` or the FP8 variant, which performs the fused QKV-A projection, RMSNorm on Q and KV components, and quantization in a single fused operation. + +- **`ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION`**: Fuses tensor-parallel allreduce with RMSNorm. + +- **FP8 MQA logits**: `fp8_mqa_logits` and `deepgemm_fp8_paged_mqa_logits` implement FP8-precision attention score computation for MLA decode. + +- **FP4 support**: MXFP4 quantized GEMM kernels (`gemm_afp4wfp4_preshuffle`, `gemm_a16wfp4_preshuffle`) and FP4 block-scale BMM via `is_rocm_aiter_fp4bmm_enabled()`. + +### Qwen3-MoE: QK Norm + RoPE + Cache + Quant Fusion + +When `ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION` is enabled, the `Qwen3MoeAttention` module: +1. Precomputes a joint `cos_sin_cache` by concatenating cosine and sine RoPE caches. +2. Passes `q_norm` and `k_norm` directly to the `Attention` module. +3. The attention backend then fuses QK normalization, RoPE application, KV cache write, and optional quantization into a single kernel pass. + +Additionally, `ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION` fuses allreduce with RMSNorm for both attention output and MoE output, reducing communication overhead. + +### MTP: DeepSeek Multi-Token Prediction + +The `DeepSeekMTP` model serves as a speculative draft model: +- Each `DeepSeekMultiTokenPredictorLayer` takes the previous hidden state and the next token's embedding, normalizes both (`enorm`, `hnorm`), concatenates them, and passes through a linear projection (`eh_proj`) followed by a standard `DeepseekV2DecoderLayer`. +- The `SharedHead` provides per-layer norm + LM head for logit computation. +- For FP4 quantized main models, MTP blocks fall back to non-FP4 quantization config to maintain draft model accuracy. + +--- + +## Source Files + +| File | Description | +|------|-------------| +| `atom/model_engine/model_runner.py` | Model registry (`support_model_arch_dict`) and `ModelRunner` class | +| `atom/models/llama.py` | Llama model: `LlamaForCausalLM`, `LlamaModel`, `LlamaDecoderLayer`, `LlamaAttention`, `LlamaMLP` | +| `atom/models/qwen3.py` | Qwen3 model: `Qwen3ForCausalLM`, `Qwen3Model`, `Qwen3DecoderLayer`, `Qwen3Attention`, `Qwen3MLP` | +| `atom/models/qwen3_moe.py` | Qwen3-MoE model: `Qwen3MoeForCausalLM`, `Qwen3MoeModel`, `Qwen3MoeDecoderLayer`, `Qwen3MoeAttention`, `Qwen3MoeSparseMoeBlock`, `Qwen3MoeMLP` | +| `atom/models/deepseek_v2.py` | DeepSeek V2/V3 model: `DeepseekV2ForCausalLM`, `DeepseekV3ForCausalLM`, `DeepseekV2Model`, `DeepseekV2DecoderLayer`, `DeepseekV2MLAAttention`, `DeepseekV2MoE`, `DeepseekV2MLP` | +| `atom/models/deepseek_mtp.py` | DeepSeek MTP draft model: `DeepSeekMTP`, `DeepSeekMultiTokenPredictor`, `DeepSeekMultiTokenPredictorLayer`, `SharedHead` | +| `atom/models/mixtral.py` | Mixtral model: `MixtralForCausalLM`, `MixtralModel`, `MixtralDecoderLayer`, `MixtralAttention`, `MixtralMoE` | +| `atom/models/gpt_oss.py` | GPT-OSS model: `GptOssForCausalLM`, `GptOssModel`, `TransformerBlock`, `OAIAttention`, `MLPBlock` | +| `atom/models/glm4_moe.py` | GLM4-MoE model: `Glm4MoeForCausalLM`, `Glm4MoeModel`, `Glm4MoeDecoderLayer`, `Glm4MoeAttention`, `Glm4MoE`, `Glm4MoeMLP` | +| `atom/models/utils.py` | Model utilities: `IntermediateTensors`, `PPMissingLayer`, `make_layers`, `maybe_prefix`, `extract_layer_index`, `should_ignore_layer`, `get_quant_config_for_layer` | +| `atom/model_loader/loader.py` | Weight loading: `load_model`, `safetensors_weights_iterator`, `default_weight_loader` | +| `atom/model_loader/weight_utils.py` | Weight utilities: `download_weights_from_hf`, `set_weight_attrs`, `filter_duplicate_safetensors_files` | diff --git a/docs/scheduling_kv_cache_guide.md b/docs/scheduling_kv_cache_guide.md new file mode 100644 index 000000000..d732953de --- /dev/null +++ b/docs/scheduling_kv_cache_guide.md @@ -0,0 +1,595 @@ +# ATOM Scheduling & KV Cache Guide + +ATOM (AiTer Optimized Model) uses a prefill-first scheduler with paged KV cache block management to drive LLM inference on AMD ROCm/HIP GPUs. This guide covers the scheduling algorithm, batch construction, block-level KV cache management, prefix caching, postprocessing, speculative decoding integration, and sequence lifecycle. + +## Quick Reference + +| Class | File | Purpose | +|---|---|---| +| `Scheduler` | `atom/model_engine/scheduler.py` | Orchestrates prefill/decode scheduling, preemption, and postprocessing | +| `ScheduledBatch` | `atom/model_engine/scheduler.py` | Immutable snapshot of a scheduled batch sent to the model runner | +| `ScheduledBatchOutput` | `atom/model_engine/scheduler.py` | Holds sampled token IDs and draft token IDs returned from forward pass | +| `BlockManager` | `atom/model_engine/block_manager.py` | Manages paged KV cache blocks with allocation, deallocation, and prefix caching | +| `Block` | `atom/model_engine/block_manager.py` | Single KV cache block with ID, reference count, hash, and token IDs | +| `Sequence` | `atom/model_engine/sequence.py` | Tracks a single request through its lifetime (tokens, blocks, status, timing) | +| `SequenceStatus` | `atom/model_engine/sequence.py` | Enum: `WAITING`, `RUNNING`, `FINISHED`, `EXIT_ENGINE` | +| `SequenceType` | `atom/model_engine/sequence.py` | Enum: `DUMMY`, `PREFILL`, `DECODE` | +| `RequestOutput` | `atom/model_engine/request.py` | Dataclass streamed to clients with new tokens and finish status | +| `Config` | `atom/config.py` | Scheduling-related fields: `max_num_seqs`, `max_num_batched_tokens`, `kv_cache_block_size`, etc. | + +**Key config defaults:** + +| Field | Default | Description | +|---|---|---| +| `max_num_seqs` | 512 | Maximum sequences in a single batch | +| `max_num_batched_tokens` | 16384 | Maximum tokens scheduled in a single step | +| `kv_cache_block_size` | 16 | Tokens per KV cache block (must be multiple of 16, or 1) | +| `enable_prefix_caching` | `False` | Enable hash-based prefix block sharing | +| `scheduler_delay_factor` | 0.0 | Delay factor for batching prompt requests (0 = no delay) | +| `gpu_memory_utilization` | 0.9 | Fraction of GPU memory for KV cache | + +--- + +## 1. Scheduling Algorithm + +The scheduler implements a **prefill-first** policy: all waiting (prefill) requests are scheduled before any running (decode) requests. The entry point is `Scheduler.schedule()`, which returns a `(ScheduledBatch, dict[int, Sequence])` tuple or `None` if both queues are empty. + +### 1.1 Scheduler Initialization + +```python +class Scheduler: + def __init__(self, config: Config): + self.max_num_seqs = config.max_num_seqs + self.max_num_batched_tokens = config.max_num_batched_tokens + self.bos_token_id = config.bos_token_id + self.eos_token_id = config.eos_token_id + self.stop_token_ids = config.stop_token_ids + self.block_manager = BlockManager(config) + self.waiting: deque[Sequence] = deque() + self.running: deque[Sequence] = deque() + self.prev_time = 0.0 + self.prev_prompt = False + self.last_prompt_latency = 0.0 + self.delay_factor = config.scheduler_delay_factor + self.use_spec = config.speculative_config is not None + self.mtp_k: int = ( + config.speculative_config.num_speculative_tokens if self.use_spec else 0 + ) + self.total_draft_tokens = 0 + self.total_accepted_tokens = 0 +``` + +The scheduler maintains two deques -- `waiting` (pending prefill) and `running` (active decode) -- plus a `BlockManager` for KV cache allocation. + +### 1.2 Schedule Flow + +`Scheduler.schedule()` proceeds in two phases: + +**Phase 1 -- Prefill scheduling:** + +1. While the delay gate passes (`_passed_delay`), the waiting queue is non-empty, and `num_seqs_prefill < max_num_seqs`: + - Peek the first waiting sequence. + - Compute `num_new_tokens = seq.num_tokens - seq.num_cached_tokens` (prefix cache hits reduce new tokens). + - If `num_batched_tokens + num_new_tokens > max_num_batched_tokens` or `block_manager.can_allocate(seq)` returns `False`, break. + - Otherwise: allocate blocks, set `seq.status = RUNNING`, `seq.type = PREFILL`, move from `waiting` to `running`. +2. If any prefill sequences were scheduled, return the batch immediately (no decode mixing). + +**Phase 2 -- Decode scheduling (only when zero prefills were scheduled):** + +1. Pop sequences from `running` up to `max_num_seqs`. +2. For each sequence, check `block_manager.can_append(seq)`. +3. If a block cannot be appended, **preempt** the last running sequence (move it back to `waiting` with status `WAITING` and deallocate its blocks). +4. If the sequence has speculative draft tokens (`seq.spec_token_ids`), record them in `scheduled_spec_decode_tokens`. +5. Call `block_manager.may_append(seq, num_new_tokens)` where `num_new_tokens = mtp_k + 1`. +6. Re-insert all scheduled sequences back into `running` (preserving order). + +### 1.3 Delay Factor + +When `scheduler_delay_factor > 0`, the scheduler delays prefill scheduling to allow the waiting queue to accumulate more requests for better batching: + +```python +def _passed_delay(self, now: float) -> bool: + if self.prev_prompt: + self.last_prompt_latency = now - self.prev_time + self.prev_time, self.prev_prompt = now, False + if self.delay_factor > 0 and self.waiting: + earliest_arrival_time = min([seq.arrive_time for seq in self.waiting]) + passed_delay = (now - earliest_arrival_time) > ( + self.delay_factor * self.last_prompt_latency + ) or not self.running + else: + passed_delay = True + return passed_delay +``` + +A new prefill is scheduled only when the earliest waiting request has waited longer than `delay_factor * last_prompt_latency`, or when there are no running decode requests. + +### 1.4 Preemption + +When a decode step cannot extend a sequence's KV cache (no free blocks), the scheduler preempts the **last** running sequence: + +```python +def preempt(self, seq: Sequence): + seq.status = SequenceStatus.WAITING + self.block_manager.deallocate(seq) + self.waiting.appendleft(seq) +``` + +The preempted sequence is pushed to the front of the waiting queue and its blocks are fully deallocated, so it will be re-prefilled on the next scheduling cycle. + +--- + +## 2. ScheduledBatch Structure + +`ScheduledBatch` is constructed by `Scheduler.schedule()` and passed to the model runner. It is a frozen snapshot of batch metadata. + +### 2.1 Constructor Signature + +```python +class ScheduledBatch: + def __init__( + self, + seqs: dict[int, Sequence], + num_scheduled_tokens: list[int], + total_tokens_num: int, + total_tokens_num_prefill: int = 0, + total_tokens_num_decode: int = 0, + total_seqs_num: int = 0, + total_seqs_num_prefill: int = 0, + total_seqs_num_decode: int = 0, + is_dummy_run: bool = False, + num_spec_step: int = 0, + scheduled_spec_decode_tokens: dict[int, list[int]] = {}, + ): +``` + +### 2.2 Fields + +| Field | Type | Description | +|---|---|---| +| `req_ids` | `list[int]` | Sequence IDs in batch order (`list(seqs.keys())`) | +| `scheduled_tokens` | `list[list[int]]` | Last `num_tokens` token IDs per sequence (the tokens to process) | +| `temperatures` | `list[float]` | Sampling temperature per sequence | +| `context_lens` | `list[int]` | Total token count per sequence (`seq.num_tokens`) | +| `block_tables` | `list[list[int]]` | Block ID tables for sequences that have block tables | +| `last_block_num_tokens` | `list[int]` | Number of valid tokens in each sequence's last block | +| `num_cached_tokens` | `list[int]` | Number of tokens served from prefix cache per sequence | +| `num_scheduled_tokens` | `list[int]` | Number of new tokens scheduled per sequence | +| `total_tokens_num` | `int` | Sum of all scheduled tokens across all sequences | +| `total_tokens_num_prefill` | `int` | Total scheduled tokens for prefill sequences | +| `total_tokens_num_decode` | `int` | Total scheduled tokens for decode sequences | +| `total_seqs_num` | `int` | Total number of sequences in the batch | +| `total_seqs_num_prefill` | `int` | Number of prefill sequences | +| `total_seqs_num_decode` | `int` | Number of decode sequences | +| `is_dummy_run` | `bool` | Whether this is a dummy/warmup run | +| `num_spec_step` | `int` | Number of speculative decode steps (`mtp_k`) | +| `scheduled_spec_decode_tokens` | `dict[int, list[int]]` | Draft token IDs per sequence ID from prior speculative step | + +### 2.3 ScheduledBatchOutput + +Returned by the model runner after a forward pass: + +```python +class ScheduledBatchOutput: + def __init__( + self, + token_ids: dict[int, tuple[int, ...]], + draft_token_ids, + ): + self.req_ids = list(token_ids.keys()) + self.token_ids = token_ids # {seq_id: (accepted_token_ids...)} + self.draft_token_ids = draft_token_ids # {seq_id: [draft_ids]} or None +``` + +- `token_ids` maps sequence ID to a tuple of accepted token IDs. +- `draft_token_ids` maps sequence ID to a list of speculative draft token IDs for the next step (when MTP is active). +- A special key `-1` in `token_ids` signals deferred output mode. + +--- + +## 3. Block Manager + +The `BlockManager` implements paged KV cache management with fixed-size blocks. + +### 3.1 Block Class + +```python +class Block: + def __init__(self, block_id): + self.block_id = block_id # Unique integer ID + self.ref_count = 0 # Number of sequences referencing this block + self.hash = -1 # xxhash64 digest for prefix caching (-1 = unhashed) + self.token_ids = [] # Token IDs stored in this block +``` + +Methods: +- `update(hash, token_ids)` -- Sets the block's hash and token content. +- `reset()` -- Sets `ref_count = 1`, `hash = -1`, `token_ids = []` (used on fresh allocation). + +### 3.2 BlockManager Initialization + +```python +class BlockManager: + def __init__(self, config: Config): + block_size = config.kv_cache_block_size # Tokens per block (default 16) + num_blocks = config.num_kvcache_blocks # Total blocks in pool + self.block_size = block_size + self.blocks: list[Block] = [Block(i) for i in range(num_blocks)] + self.hash_to_block_id: dict[int, int] = dict() + self.free_block_ids: deque[int] = deque(range(num_blocks)) + self.used_block_ids: set[int] = set() + self.enable_prefix_caching = config.enable_prefix_caching +``` + +The block pool is pre-allocated at startup. `free_block_ids` is a deque for O(1) pop/push, `used_block_ids` tracks active blocks, and `hash_to_block_id` maps content hashes to block IDs for prefix caching. + +### 3.3 Allocation (`allocate`) + +Called during prefill scheduling for new sequences: + +```python +def allocate(self, seq: Sequence): +``` + +1. Iterates over `seq.num_blocks` blocks. +2. For each block, computes hash if the block is full (`len(token_ids) == block_size`). Partial (last) blocks get `hash = -1`. +3. If prefix caching is enabled, looks up `hash_to_block_id`: + - **Cache hit:** Verifies `token_ids` match. If the block is already in `used_block_ids`, increments `ref_count`. If it was evicted but still in the free list, re-allocates it. Increments `seq.num_cached_tokens` by `block_size`. + - **Cache miss:** Allocates from `free_block_ids[0]`. +4. Full blocks are registered in `hash_to_block_id`. + +### 3.4 Deallocation (`deallocate`) + +Called when a sequence finishes or is preempted: + +```python +def deallocate(self, seq: Sequence): + for block_id in reversed(seq.block_table): + block = self.blocks[block_id] + block.ref_count -= 1 + if block.ref_count == 0: + self._deallocate_block(block_id) + seq.num_cached_tokens = 0 + seq.block_table.clear() +``` + +Blocks are released in reverse order. Shared blocks (with `ref_count > 1` from prefix caching) are not freed until all referencing sequences release them. + +### 3.5 Can-Allocate and Can-Append Checks + +```python +def can_allocate(self, seq: Sequence) -> bool: + return len(self.free_block_ids) >= seq.num_blocks + +def can_append(self, seq: Sequence) -> bool: + return len(self.free_block_ids) >= (len(seq) % self.block_size == 1) +``` + +- `can_allocate` checks that enough free blocks exist for the full sequence. +- `can_append` checks whether a decode step needs a new block. A new block is needed only when `len(seq) % block_size == 1` (the previous block just filled up), requiring exactly 1 free block. + +### 3.6 May-Append (Decode Extension) + +```python +def may_append(self, seq: Sequence, num_new_tokens: int = 1): +``` + +Called during decode scheduling to extend a sequence's block table: + +1. If the sequence length modulo `block_size` falls within `(0, num_new_tokens]`, or `block_size == 1`, a new block is needed: + - Allocates from `free_block_ids` and appends to `block_table`. + - For `block_size == 1`, immediately computes and stores the hash. +2. If `seq_len % block_size == 0`, the last block is now full -- computes and stores its hash using the chained prefix. +3. Otherwise the last block is partially filled with `hash = -1` (hash deferred until full). + +--- + +## 4. Prefix Caching + +Prefix caching enables sharing KV cache blocks across sequences that share a common prompt prefix, avoiding redundant computation. + +### 4.1 Hash Function + +ATOM uses `xxhash64` (via the `xxhash` Python library) for fast, collision-resistant block hashing: + +```python +@classmethod +def compute_hash(cls, token_ids: list[int], prefix: int = -1): + h = xxhash.xxh64() + if prefix != -1: + h.update(prefix.to_bytes(8, "little")) + h.update(np.array(token_ids).tobytes()) + return h.intdigest() +``` + +### 4.2 Hash Chaining + +Blocks form a hash chain: each block's hash incorporates the previous block's hash as a prefix. This ensures that two blocks with identical token content but different preceding context produce different hashes. + +- First block: `compute_hash(token_ids, prefix=-1)` (no prefix). +- Subsequent blocks: `compute_hash(token_ids, prefix=prev_block.hash)`. +- Only **full** blocks (where `len(token_ids) == block_size`) receive a hash. Partial blocks have `hash = -1` and are not cached. + +### 4.3 Cache Lookup During Allocation + +During `allocate()`, for each full block: + +1. Compute the block hash via the chain. +2. Look up `hash_to_block_id.get(h, -1)`. +3. If found, verify `self.blocks[block_id].token_ids == token_ids` (guard against hash collisions). +4. **Hit:** Reuse the block. If already in `used_block_ids`, increment `ref_count`. Add `block_size` to `seq.num_cached_tokens`. +5. **Miss (or first miss in chain):** Once a cache miss occurs, all subsequent blocks in the sequence are also misses (`cache_miss = True` is sticky). Allocate fresh blocks from the free list. + +### 4.4 Reference Counting + +- On allocation: `block.reset()` sets `ref_count = 1`. +- On cache hit for an in-use block: `ref_count += 1`. +- On deallocation: `ref_count -= 1`. Block returns to free list only when `ref_count == 0`. +- Shared blocks (prefix cache hits) have `ref_count > 1`. + +### 4.5 Enabling Prefix Caching + +Set `enable_prefix_caching=True` in `Config`. When disabled, the hash lookup in `allocate()` is skipped entirely (`block_id` is always `-1`). + +--- + +## 5. Postprocessing + +`Scheduler.postprocess()` is called after the model forward pass to update sequences with sampled tokens, check stop conditions, generate streaming output, and clean up finished sequences. + +### 5.1 Signature + +```python +def postprocess( + self, + seqs: list[Sequence], + fwd_output: ScheduledBatchOutput, + stream_output_queue=None, +) -> list[Sequence]: +``` + +### 5.2 Token Appending + +For each running sequence whose ID appears in `fwd_output.req_ids`: + +- **Deferred output or speculative decode with EOS:** Replaces placeholder tokens in-place: + ```python + seq.token_ids[-num_placeholder:] = token_ids + seq.output_tokens[-num_placeholder:] = token_ids + ``` +- **Normal path:** Calls `seq.append_token(token_id)` for each accepted token, which appends to `token_ids`, updates `output_tokens`, `last_token`, and `num_tokens`. + +### 5.3 Stop Condition Checking + +The postprocessor checks stop conditions in priority order: + +1. **Stop token sequences:** Compares the tail of `seq.token_ids` against each entry in `seq.stop_token_sequences`. Also checks the MTP-adjusted position for speculative decode. Sets `leave_reason = "stop_sequence"`. +2. **EOS token:** If `self.eos_token_id` appears in the accepted tokens and `seq.ignore_eos` is `False`. Sets `leave_reason = "eos"`. +3. **Stop token IDs:** If any accepted token is in `self.stop_token_ids` (from `Config.stop_token_ids`, derived from the model's generation config). Sets `leave_reason = "stop_{token_id}"`. +4. **Max tokens:** If `seq.num_completion_tokens >= seq.max_tokens`. Sets `leave_reason = "max_tokens"`. + +### 5.4 Stream Output + +When `stream_output_queue` is provided, the scheduler creates a `RequestOutput` for each processed sequence: + +```python +request_output = RequestOutput( + request_id=seq.id, + output_tokens=output_tokens_list, + finished=(leave_reason is not None), + finish_reason=leave_reason, +) +``` + +`RequestOutput` fields: + +| Field | Type | Description | +|---|---|---| +| `request_id` | `int` | Sequence ID | +| `output_tokens` | `list[int]` | Newly generated tokens since last callback | +| `finished` | `bool` | Whether the sequence is done | +| `finish_reason` | `Optional[str]` | One of: `"eos"`, `"max_tokens"`, `"stop_sequence"`, `"stop_{token_id}"`, or `None` | + +Stream outputs are batched and put onto `stream_output_queue` via `put_nowait`. + +### 5.5 Sequence Cleanup + +For finished sequences: +1. Set `seq.status = SequenceStatus.FINISHED`. +2. Call `block_manager.deallocate(seq)` to free KV cache blocks. +3. Remove from the `running` deque. +4. Return in the `finished_seqs` list. + +### 5.6 Placeholder Insertion + +When speculative decoding or deferred output is active, placeholder EOS tokens are appended to still-running sequences to reserve KV cache slots for the next step: + +```python +if need_placeholder: + for seq in seqs: + if seq.status == SequenceStatus.RUNNING: + for _ in range(seq.num_placeholder): + seq.append_token(self.eos_token_id) +``` + +The placeholder count is determined as follows: + +- **For sequences processed in this step** (had output in `fwd_output`): always `1 + mtp_k`, regardless of mode. +- **For sequences not processed** (skipped in this step): the count depends on the batch-level mode: + - Deferred output + speculative: `mtp_k + 1` + - Deferred output only: `1` + - Speculative only: `mtp_k` + +--- + +## 6. Speculative Decoding Integration + +ATOM supports Multi-Token Prediction (MTP) speculative decoding, where a draft model proposes `mtp_k` additional tokens per step. + +### 6.1 Scheduler Tracking + +```python +self.use_spec = config.speculative_config is not None +self.mtp_k: int = config.speculative_config.num_speculative_tokens if self.use_spec else 0 +self.total_draft_tokens = 0 +self.total_accepted_tokens = 0 +``` + +Note: `SpeculativeConfig` currently enforces `num_speculative_tokens == 1`. + +### 6.2 Draft Tokens in Scheduling + +During decode scheduling: +- If `seq.spec_token_ids` is non-empty, the draft tokens are recorded in `scheduled_spec_decode_tokens[seq.id]`. +- `num_new_tokens = mtp_k + 1` (1 target + `mtp_k` draft tokens), so `may_append` reserves enough block space. +- The `ScheduledBatch` carries `num_spec_step = mtp_k` and the `scheduled_spec_decode_tokens` dict. + +### 6.3 Acceptance Statistics + +```python +def update_spec_stats(self, num_accepted_tokens): + self.total_draft_tokens += self.mtp_k + self.total_accepted_tokens += num_accepted_tokens - self.mtp_k +``` + +Every 1000 draft tokens, the acceptance rate is logged: + +``` +[MTP Stats] Total draft tokens: 5000, Accepted: 3750, Acceptance rate: 75.00% +``` + +### 6.4 Draft Token Storage on Sequences + +After postprocessing, accepted draft token IDs for the next step are stored on the sequence: + +```python +if draft_token_ids and seq.id in draft_token_ids: + seq.spec_token_ids = draft_token_ids[seq.id] +``` + +These are picked up by the scheduler on the next `schedule()` call. + +--- + +## 7. Sequence Management + +The `Sequence` class represents a single request throughout its lifecycle. + +### 7.1 Constructor + +```python +class Sequence: + def __init__( + self, + token_ids: list[int], + block_size: int, + sampling_params=SamplingParams(), + stop_token_sequences: list[list[int]] = None, + stream_callback: Optional[Callable[[Any], None]] = None, + id=None, + ): +``` + +### 7.2 Core Fields + +| Field | Type | Description | +|---|---|---| +| `id` | `int` | Auto-incrementing unique ID (from `itertools.count`) | +| `token_ids` | `list[int]` | Full token sequence (prompt + completion) | +| `block_size` | `int` | KV cache block size (from config) | +| `status` | `SequenceStatus` | Current lifecycle state | +| `type` | `SequenceType` | Current step type (`DUMMY`, `PREFILL`, `DECODE`) | +| `num_tokens` | `int` | Total tokens (prompt + completion); property with setter that also updates `num_blocks` and `last_block_num_tokens` | +| `num_prompt_tokens` | `int` | Number of prompt tokens (fixed at init) | +| `num_cached_tokens` | `int` | Tokens served from prefix cache | +| `block_table` | `list[int]` | Ordered list of block IDs assigned to this sequence | +| `last_token` | `int` | Most recently appended token ID | +| `temperature` | `float` | Sampling temperature (from `SamplingParams`) | +| `max_tokens` | `int` | Max completion tokens (from `SamplingParams`, default 64) | +| `ignore_eos` | `bool` | Whether to ignore EOS tokens (from `SamplingParams`) | +| `stop_strings` | `Optional[list[str]]` | Stop strings (from `SamplingParams`) | +| `stop_token_sequences` | `list[list[int]]` | Token-level stop sequences | +| `stream_callback` | `Optional[Callable]` | Per-sequence stream callback | +| `output_tokens` | `list[int]` | Cache of newly generated tokens | +| `spec_token_ids` | `list[int]` | Speculative draft token IDs for next step | +| `num_placeholder` | `int` | Number of placeholder tokens inserted for speculative/deferred output | + +### 7.3 Timing Fields + +| Field | Type | Description | +|---|---|---| +| `arrive_time` | `float` | Timestamp when the sequence entered the scheduler | +| `first_token_time` | `float` | Timestamp of the first completion token (TTFT measurement) | +| `leave_time` | `float` | Timestamp when the sequence finished | +| `leave_reason` | `str` | Reason for finishing (e.g., `"eos"`, `"max_tokens"`, `"stop_sequence"`) | + +### 7.4 Computed Properties + +| Property | Returns | +|---|---| +| `num_completion_tokens` | `num_tokens - num_prompt_tokens` | +| `prompt_token_ids` | `token_ids[:num_prompt_tokens]` | +| `completion_token_ids` | `token_ids[num_prompt_tokens:]` | +| `num_cached_blocks` | `num_cached_tokens // block_size` | +| `is_finished` | `status == SequenceStatus.FINISHED` | + +### 7.5 num_tokens Setter + +Setting `num_tokens` triggers derived field updates: + +```python +@num_tokens.setter +def num_tokens(self, value): + self._num_tokens = value + self.num_blocks = (value + self.block_size - 1) // self.block_size + self.last_block_num_tokens = self._num_tokens - (self.num_blocks - 1) * self.block_size +``` + +### 7.6 Lifecycle + +``` + allocate blocks + add(seq) ---------> WAITING ---------> RUNNING (PREFILL) + ^ | + | | next schedule() step + preempt() v + | RUNNING (DECODE) <--+ + +--- can't append | | + | stop condition met + v + FINISHED + | + | deallocate blocks + v + (removed from running) +``` + +### 7.7 SequenceStatus Enum + +| Value | Meaning | +|---|---| +| `WAITING` | In the waiting queue, pending prefill | +| `RUNNING` | Actively being processed (prefill or decode) | +| `FINISHED` | Stop condition met, blocks deallocated | +| `EXIT_ENGINE` | Sentinel for engine shutdown | + +### 7.8 SequenceType Enum + +| Value | Meaning | +|---|---| +| `DUMMY` | Initial state before scheduling | +| `PREFILL` | Currently in prefill phase | +| `DECODE` | Currently in decode phase | + +--- + +## Source Files + +| File | Description | +|---|---| +| `atom/model_engine/scheduler.py` | `Scheduler`, `ScheduledBatch`, `ScheduledBatchOutput` -- scheduling algorithm, postprocessing, speculative decode stats | +| `atom/model_engine/block_manager.py` | `Block`, `BlockManager` -- paged KV cache block pool, allocation/deallocation, prefix caching with xxhash64 | +| `atom/model_engine/sequence.py` | `Sequence`, `SequenceStatus`, `SequenceType` -- request lifecycle, token management, timing | +| `atom/model_engine/request.py` | `RequestOutput` -- streaming output dataclass with `request_id`, `output_tokens`, `finished`, `finish_reason` | +| `atom/config.py` | `Config` -- scheduling-related fields (`max_num_seqs`, `max_num_batched_tokens`, `kv_cache_block_size`, `enable_prefix_caching`, `scheduler_delay_factor`), `SpeculativeConfig` | +| `atom/sampling_params.py` | `SamplingParams` -- `temperature`, `max_tokens`, `ignore_eos`, `stop_strings` | diff --git a/docs/serving_benchmarking_guide.md b/docs/serving_benchmarking_guide.md new file mode 100644 index 000000000..8df5f813f --- /dev/null +++ b/docs/serving_benchmarking_guide.md @@ -0,0 +1,648 @@ +# ATOM Serving & Benchmarking Guide + +ATOM (AiTer Optimized Model) is AMD's lightweight LLM inference engine built on +[AITER](https://github.com/ROCm/aiter) kernels for ROCm/HIP GPUs. This guide +covers the OpenAI-compatible serving API, programmatic engine usage, benchmarking +tools, profiling, and speculative decoding. + +--- + +## Quick Reference + +```bash +# Start the OpenAI-compatible server +python -m atom.entrypoints.openai_server --model --kv_cache_dtype fp8 + +# Run the online serving benchmark +python -m atom.benchmarks.benchmark_serving \ + --backend vllm --model \ + --base-url http://localhost:8000 \ + --dataset-name random --random-input-len 1024 --random-output-len 128 \ + --num-prompts 1000 --request-rate inf --ignore-eos + +# Simple inference example +python -m atom.examples.simple_inference --model --kv_cache_dtype fp8 + +# Offline profiling +python -m atom.examples.profile_offline --model --kv_cache_dtype fp8 + +# Accuracy validation with lm-eval +lm_eval --model local-completions \ + --model_args model=,base_url=http://localhost:8000/v1/completions,num_concurrent=64,max_retries=3,tokenized_requests=False \ + --tasks gsm8k --num_fewshot 5 +``` + +--- + +## 1. OpenAI-Compatible Server + +The server is implemented in `atom/entrypoints/openai_server.py` using FastAPI +and Uvicorn. It exposes OpenAI-compatible HTTP endpoints so that existing +clients (curl, OpenAI SDK, lm-eval) work without modification. + +### 1.1 Endpoints + +| Method | Path | Description | +|--------|------|-------------| +| `POST` | `/v1/chat/completions` | Chat completion (ChatCompletionRequest -> ChatCompletionResponse) | +| `POST` | `/v1/completions` | Text completion (CompletionRequest -> CompletionResponse) | +| `GET` | `/v1/models` | List available models | +| `GET` | `/health` | Health check (returns `{"status": "ok"}`) | +| `POST` | `/start_profile` | Start torch profiler on the engine | +| `POST` | `/stop_profile` | Stop torch profiler and flush traces | + +### 1.2 Request Models + +**ChatCompletionRequest** fields: + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `model` | `Optional[str]` | `None` | Model name (validated against the loaded model) | +| `messages` | `Optional[List[ChatMessage]]` | `None` | List of chat messages (`role`, `content`) | +| `prompt` | `Optional[List[ChatMessage]]` | `None` | Alias for `messages` | +| `temperature` | `Optional[float]` | `1.0` | Sampling temperature | +| `top_p` | `Optional[float]` | `1.0` | Nucleus sampling threshold | +| `max_tokens` | `Optional[int]` | `256` | Maximum tokens to generate | +| `stop` | `Optional[List[str]]` | `None` | Stop strings | +| `ignore_eos` | `Optional[bool]` | `False` | Ignore end-of-sequence token | +| `stream` | `Optional[bool]` | `False` | Enable server-sent events streaming | +| `seed` | `Optional[int]` | `None` | Random seed | + +**CompletionRequest** fields: + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `model` | `Optional[str]` | `None` | Model name | +| `prompt` | `str` | (required) | Text prompt | +| `temperature` | `Optional[float]` | `1.0` | Sampling temperature | +| `top_p` | `Optional[float]` | `1.0` | Nucleus sampling threshold | +| `max_tokens` | `Optional[int]` | `256` | Maximum tokens to generate | +| `stop` | `Optional[List[str]]` | `None` | Stop strings | +| `ignore_eos` | `Optional[bool]` | `False` | Ignore end-of-sequence token | +| `stream` | `Optional[bool]` | `False` | Enable SSE streaming | + +### 1.3 Response Models + +Both `ChatCompletionResponse` and `CompletionResponse` include: + +- `id` -- unique request identifier (e.g. `chatcmpl-` or `cmpl-`) +- `object` -- `"chat.completion"` or `"text_completion"` +- `created` -- Unix timestamp +- `model` -- model name +- `choices` -- list of generated completions +- `usage` -- token counts (`prompt_tokens`, `completion_tokens`, `total_tokens`) + plus `ttft_s`, `tpot_s`, and `latency_s` timing fields + +Streaming responses use the SSE (Server-Sent Events) protocol with +`data: [DONE]\n\n` as the termination signal. + +### 1.4 Server Startup + +```bash +python -m atom.entrypoints.openai_server \ + --model \ + --kv_cache_dtype fp8 \ + --host 0.0.0.0 \ + --server-port 8000 +``` + +Server-specific CLI arguments: + +| Argument | Default | Description | +|----------|---------|-------------| +| `--host` | `0.0.0.0` | Bind address | +| `--server-port` | `8000` | HTTP port (note: `--port` is for internal engine communication) | + +All `EngineArgs` arguments are also accepted (see Section 7 for the full list). + +### 1.5 Example: curl + +```bash +# Non-streaming chat completion +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "deepseek-ai/DeepSeek-R1", + "messages": [{"role": "user", "content": "Hello!"}], + "max_tokens": 128 + }' + +# Streaming text completion +curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "The capital of France is", + "max_tokens": 64, + "stream": true + }' +``` + +--- + +## 2. Programmatic API (LLMEngine) + +The `LLMEngine` class in `atom/model_engine/llm_engine.py` provides a +Python-native interface for inference without running an HTTP server. + +### 2.1 Initialization + +```python +from atom import LLMEngine, SamplingParams + +engine = LLMEngine(model="deepseek-ai/DeepSeek-R1", kv_cache_dtype="fp8", + tensor_parallel_size=8) +``` + +`LLMEngine.__init__(model, **kwargs)` accepts all `Config` field names as +keyword arguments (e.g. `tensor_parallel_size`, `kv_cache_dtype`, +`max_model_len`, `data_parallel_size`, `gpu_memory_utilization`). + +### 2.2 SamplingParams + +Defined in `atom/sampling_params.py`: + +```python +@dataclass +class SamplingParams: + temperature: float = 1.0 + max_tokens: int = 64 + ignore_eos: bool = False + stop_strings: Optional[list[str]] = None +``` + +### 2.3 Core Methods + +| Method | Signature | Description | +|--------|-----------|-------------| +| `generate` | `(prompts: list[str], sampling_params) -> list[dict]` | Synchronous batch generation; blocks until all prompts complete | +| `add_request` | `(prompt_or_tokens_list, sampling_params_list, stream_callback=None)` | Submit requests for asynchronous processing | +| `step` | `() -> list[Sequence]` | Retrieve completed sequences | +| `is_finished` | `() -> bool` | Check whether all pending requests have completed | +| `start_profile` | `()` | Start torch profiler on all workers | +| `stop_profile` | `()` | Stop torch profiler and write traces | +| `print_mtp_statistics` | `()` | Print speculative decoding acceptance statistics | + +### 2.4 Synchronous Generation Example + +```python +from atom import LLMEngine, SamplingParams + +engine = LLMEngine(model="meta-llama/Meta-Llama-3-8B", kv_cache_dtype="fp8") +params = SamplingParams(temperature=0.6, max_tokens=256) + +outputs = engine.generate(["Explain quantum computing in simple terms."], params) +for out in outputs: + print(out["text"]) +``` + +Each output dictionary contains: `text`, `token_ids`, `latency`, +`finish_reason`, `num_tokens_input`, `num_tokens_output`, `ttft`, and `tpot`. + +### 2.5 Asynchronous / Streaming Usage + +```python +engine.add_request( + prompt_or_tokens_list=["Hello world", "How are you?"], + sampling_params_list=SamplingParams(temperature=0.8, max_tokens=128), + stream_callback=my_callback, # called per-token with RequestOutput +) + +while not engine.is_finished(): + completed = engine.step() + # process completed sequences +``` + +--- + +## 3. Simple Inference + +The `atom/examples/simple_inference.py` script provides a quick way to validate +model loading and generation. + +### 3.1 Usage + +```bash +python -m atom.examples.simple_inference \ + --model meta-llama/Meta-Llama-3-8B \ + --kv_cache_dtype fp8 \ + --temperature 0.6 +``` + +### 3.2 What It Does + +1. Parses all `EngineArgs` plus `--temperature` (default `0.6`). +2. Creates an `LLMEngine` via `EngineArgs.from_cli_args(args).create_engine()`. +3. Applies the model's chat template to four built-in prompts (English and + Chinese) with `enable_thinking=True`. +4. Runs a warmup generation, then generates completions for the batch. +5. Calls `llm.print_mtp_statistics()` to report speculative decoding stats + (if MTP is enabled). + +--- + +## 4. Benchmarking + +ATOM ships a comprehensive online serving benchmark in +`atom/benchmarks/benchmark_serving.py` (adapted from vLLM's benchmarking +tooling). + +### 4.1 Metrics + +The `BenchmarkMetrics` dataclass tracks: + +| Metric | Abbreviation | Description | +|--------|--------------|-------------| +| Time to First Token | **TTFT** | Latency from request submission to the first generated token | +| Time per Output Token | **TPOT** | Average latency per output token (excluding the first) | +| Inter-Token Latency | **ITL** | Latency between successive output tokens | +| End-to-End Latency | **E2EL** | Total latency from request send to full response receipt | +| Request Throughput | -- | Completed requests per second | +| Output Token Throughput | -- | Generated tokens per second | +| Total Token Throughput | -- | (input + output) tokens per second | +| Request Goodput | -- | Requests per second meeting SLO targets | + +For each latency metric, mean, median, standard deviation, and configurable +percentiles (default: P99) are reported. + +### 4.2 Key CLI Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--backend` | `vllm` | Backend type. Choices: `tgi`, `vllm`, `lmdeploy`, `deepspeed-mii`, `openai`, `openai-chat`, `tensorrt-llm`, `scalellm`, `sglang` | +| `--model` | (required) | Model name or path | +| `--base-url` | `None` | Server base URL (e.g. `http://localhost:8000`) | +| `--host` | `127.0.0.1` | Server host (used when `--base-url` is not set) | +| `--port` | `8000` | Server port (used when `--base-url` is not set) | +| `--endpoint` | `/v1/completions` | API endpoint path | +| `--dataset-name` | `sharegpt` | Dataset type: `sharegpt`, `burstgpt`, `sonnet`, `random`, `hf` | +| `--dataset-path` | `None` | Path to dataset file or HuggingFace dataset ID | +| `--num-prompts` | `1000` | Number of prompts to benchmark | +| `--request-rate` | `inf` | Requests per second (`inf` = send all at once) | +| `--burstiness` | `1.0` | Burstiness factor (1.0 = Poisson process) | +| `--max-concurrency` | `None` | Maximum concurrent requests | +| `--ignore-eos` | `False` | Ignore EOS token in generation | +| `--save-result` | `False` | Save results to JSON | +| `--result-dir` | `None` | Directory for result JSON files | +| `--result-filename` | `None` | Custom filename for results | +| `--percentile-metrics` | `ttft,tpot,itl` | Comma-separated metrics to report percentiles for | +| `--metric-percentiles` | `99` | Comma-separated percentile values (e.g. `25,50,75,99`) | +| `--goodput` | `None` | SLO targets as `KEY:VALUE` pairs (e.g. `ttft:100 tpot:50`) | +| `--profile` | `False` | Enable torch profiler during the benchmark run | +| `--tokenizer` | `None` | Custom tokenizer name or path | +| `--seed` | `0` | Random seed | + +**Random dataset options:** + +| Argument | Default | Description | +|----------|---------|-------------| +| `--random-input-len` | `1024` | Input token length | +| `--random-output-len` | `128` | Output token length | +| `--random-range-ratio` | `1.0` | Length variation ratio | +| `--random-prefix-len` | `0` | Fixed prefix token length | +| `--use-chat-template` | `False` | Apply chat template to random prompts | + +### 4.3 Backend Request Functions + +Defined in `atom/benchmarks/backend_request_func.py`: + +| Backend Key | Function | Protocol | +|-------------|----------|----------| +| `vllm` | `async_request_openai_completions` | OpenAI Completions API (streaming) | +| `openai` | `async_request_openai_completions` | OpenAI Completions API (streaming) | +| `openai-chat` | `async_request_openai_chat_completions` | OpenAI Chat Completions API (streaming) | +| `tgi` | `async_request_tgi` | TGI `generate_stream` | +| `tensorrt-llm` | `async_request_trt_llm` | TRT-LLM `generate_stream` | +| `deepspeed-mii` | `async_request_deepspeed_mii` | DeepSpeed-MII | +| `lmdeploy` | `async_request_openai_completions` | OpenAI Completions API | +| `scalellm` | `async_request_openai_completions` | OpenAI Completions API | +| `sglang` | `async_request_openai_completions` | OpenAI Completions API | + +Each function uses `RequestFuncInput` and returns a `RequestFuncOutput` with +timing data (`ttft`, `itl`, `latency`, `tpot`). + +### 4.4 Full Benchmark Example + +```bash +# 1. Start the server +python -m atom.entrypoints.openai_server \ + --kv_cache_dtype fp8 -tp 8 --model deepseek-ai/DeepSeek-R1 + +# 2. Run benchmark +MODEL=deepseek-ai/DeepSeek-R1 +ISL=1024 +OSL=1024 +CONC=128 +PORT=8000 +RESULT_FILENAME=Deepseek-R1-result + +python -m atom.benchmarks.benchmark_serving \ + --model=$MODEL --backend=vllm --base-url=http://localhost:$PORT \ + --dataset-name=random \ + --random-input-len=$ISL --random-output-len=$OSL \ + --random-range-ratio 0.8 \ + --num-prompts=$(( $CONC * 10 )) \ + --max-concurrency=$CONC \ + --request-rate=inf --ignore-eos \ + --save-result --percentile-metrics="ttft,tpot,itl,e2el" \ + --result-dir=./ --result-filename=$RESULT_FILENAME.json +``` + +--- + +## 5. Profiling + +ATOM supports PyTorch profiling via environment variables, HTTP endpoints, and +the programmatic API. + +### 5.1 Configuration + +| Mechanism | Description | +|-----------|-------------| +| `--torch-profiler-dir ` | CLI arg to set the trace output directory | +| `ATOM_TORCH_PROFILER_DIR` env var | Sets the default `torch_profiler_dir` in `Config` | +| `ATOM_PROFILER_MORE=1` env var | Enables detailed profiling: `record_shapes`, `with_stack`, `profile_memory` | + +When a profiler directory is configured, each worker saves traces to a +rank-specific subdirectory: + +- Multi-GPU with DP: `{profiler_dir}/dp{dp_rank}_tp{rank}/` +- Single-GPU / TP-only: `{profiler_dir}/rank_{rank}/` + +Traces are saved in gzip-compressed TensorBoard format and can be viewed with +`tensorboard --logdir ` or Chrome's `chrome://tracing`. + +### 5.2 Online Profiling (HTTP) + +While the server is running, start and stop profiling with HTTP requests: + +```bash +# Start profiling +curl -s -S -X POST http://127.0.0.1:8000/start_profile + +# ... run your workload ... + +# Stop profiling and flush traces +curl -s -S -X POST http://127.0.0.1:8000/stop_profile +``` + +The server must be started with `--torch-profiler-dir` or with +`ATOM_TORCH_PROFILER_DIR` set for these endpoints to produce traces. + +### 5.3 Programmatic Profiling + +```python +engine = LLMEngine(model="Qwen/Qwen3-0.6B", torch_profiler_dir="./traces") + +engine.start_profile() +outputs = engine.generate(prompts, sampling_params) +engine.stop_profile() +# Traces written to ./traces/rank_0/ +``` + +### 5.4 Offline Profiling Script + +`atom/examples/profile_offline.py` provides a self-contained offline profiling +workflow: + +```bash +python -m atom.examples.profile_offline \ + --model Qwen/Qwen3-0.6B \ + --kv_cache_dtype fp8 \ + --torch-profiler-dir ./profiler_traces \ + --input-length 128 \ + --output-length 32 \ + --bs 4 +``` + +Script-specific arguments: + +| Argument | Default | Description | +|----------|---------|-------------| +| `--input-length` | `128` | Approximate input prompt length in tokens | +| `--output-length` | `32` | Output generation length in tokens | +| `--bs` | `1` | Batch size (number of parallel requests) | +| `--random-input` | `False` | Use random token input instead of predefined text | + +If `--torch-profiler-dir` is not specified, the script defaults to +`./profiler_traces`. + +### 5.5 Profiling During Benchmarks + +The benchmark tool can trigger profiling automatically via `--profile`: + +```bash +python -m atom.benchmarks.benchmark_serving \ + --model --backend vllm \ + --base-url http://localhost:8000 \ + --dataset-name random --num-prompts 100 \ + --profile +``` + +This sends `POST /start_profile` before the benchmark and +`POST /stop_profile` after completion. + +--- + +## 6. Speculative Decoding (MTP) + +ATOM supports Multi-Token Prediction (MTP) for DeepSeek models using the +Eagle-style speculative decoding framework. + +### 6.1 Architecture + +- **EagleProposer** (`atom/spec_decode/eagle.py`): Loads and runs the draft + (MTP) model to propose speculative tokens. Supports the `DeepSeekMTPModel` + architecture via `DeepSeekMTP`. +- **RejectionSampler** (`atom/model_ops/rejection_sampler.py`): Implements + greedy rejection sampling with a Triton kernel. Compares draft token IDs + against target model argmax and accepts matching prefixes; appends a bonus + token if all drafts are accepted. + +### 6.2 Configuration + +Enable MTP via CLI arguments: + +```bash +python -m atom.entrypoints.openai_server \ + --model deepseek-ai/DeepSeek-R1 \ + --kv_cache_dtype fp8 -tp 8 \ + --method mtp \ + --num-speculative-tokens 1 +``` + +| Argument | Default | Description | +|----------|---------|-------------| +| `--method` | `None` | Speculative method; currently only `mtp` is supported | +| `--num-speculative-tokens` | `1` | Number of draft tokens per iteration (draft model runs this many autoregressive steps) | + +### 6.3 MTP Statistics + +ATOM tracks acceptance statistics at runtime: + +- **total_draft_tokens**: Total number of draft tokens proposed +- **total_accepted_tokens**: Number of draft tokens accepted by rejection sampling +- **acceptance_rate**: Ratio of accepted to draft tokens + +Statistics are logged every 1000 draft tokens and can be printed on demand: + +```python +engine.print_mtp_statistics() +``` + +Example output: +``` +MTP Statistics: + Total draft tokens: 5000 + Accepted tokens: 4250 + Acceptance rate: 85.00% +``` + +### 6.4 How Rejection Sampling Works + +1. The draft model generates `num_speculative_tokens` token predictions + autoregressively using argmax. +2. The target model verifies all draft tokens in a single forward pass. +3. The `rejection_greedy_sample_kernel` (Triton) compares each draft token + against the target model's argmax: + - If they match, the token is accepted. + - On the first mismatch, the target model's token replaces it and all + subsequent draft tokens are discarded. + - If all draft tokens match, a bonus token from the target model is + appended. + +--- + +## 7. Deployment Examples + +### 7.1 Single-GPU + +```bash +python -m atom.entrypoints.openai_server \ + --model Qwen/Qwen3-0.6B \ + --kv_cache_dtype fp8 +``` + +### 7.2 Multi-GPU with Tensor Parallelism + +```bash +python -m atom.entrypoints.openai_server \ + --model deepseek-ai/DeepSeek-R1 \ + --kv_cache_dtype fp8 \ + -tp 8 +``` + +### 7.3 Docker Deployment + +```bash +# Pull the ROCm PyTorch image +docker pull rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0 + +# Launch container +docker run -it --network=host \ + --device=/dev/kfd \ + --device=/dev/dri \ + --group-add video \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + -v $HOME:/home/$USER \ + -v /mnt:/mnt \ + -v /data:/data \ + --shm-size=16G \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0 + +# Inside the container +pip install amd-aiter +git clone https://github.com/ROCm/ATOM.git && cd ATOM && pip install . + +# Start serving +python -m atom.entrypoints.openai_server \ + --model deepseek-ai/DeepSeek-R1 \ + --kv_cache_dtype fp8 -tp 8 +``` + +### 7.4 Engine CLI Arguments (EngineArgs) + +These arguments are available for all entrypoints (server, examples, and any +script using `EngineArgs.add_cli_args`): + +| Argument | Default | Description | +|----------|---------|-------------| +| `--model` | `Qwen/Qwen3-0.6B` | Model name or path | +| `--trust-remote-code` | `False` | Trust remote code from HuggingFace | +| `--tensor-parallel-size`, `-tp` | `1` | Tensor parallel size | +| `--data-parallel-size`, `-dp` | `1` | Data parallel size | +| `--enforce-eager` | `False` | Disable CUDA graph capture; use eager execution | +| `--enable_prefix_caching` | `False` | Enable prefix caching | +| `--port` | `8006` | Internal engine communication port | +| `--kv_cache_dtype` | `bf16` | KV cache dtype: `bf16` or `fp8` | +| `--block-size` | `16` | KV cache block size | +| `--max-model-len` | `None` | Maximum context length (defaults to HF config) | +| `--max-num-batched-tokens` | `16384` | Maximum tokens per batch | +| `--max-num-seqs` | `512` | Maximum sequences per batch | +| `--gpu-memory-utilization` | `0.9` | GPU memory utilization (0.0 to 1.0) | +| `--scheduler-delay-factor` | `0.0` | Delay factor before scheduling next prompt | +| `--cudagraph-capture-sizes` | `[1,2,4,...,256]` | Batch sizes for CUDA graph capture | +| `--level` | `3` | Compilation level (0-3); 3 = torch.compile | +| `--load_dummy` | `False` | Skip loading model weights (for testing) | +| `--enable-expert-parallel` | `False` | Enable expert parallelism for MoE | +| `--enable-dp-attention` | `False` | Enable data-parallel attention | +| `--torch-profiler-dir` | `None` | Directory for torch profiler traces | +| `--method` | `None` | Speculative decoding method (`mtp`) | +| `--num-speculative-tokens` | `1` | Number of speculative tokens per step | + +--- + +## 8. Accuracy Validation + +ATOM supports accuracy validation through the +[lm-eval](https://github.com/EleutherAI/lm-evaluation-harness) framework via +the OpenAI-compatible API. + +### 8.1 Setup + +```bash +pip install lm-eval[api] +``` + +### 8.2 Run Evaluation + +Start an ATOM server, then run lm-eval against it: + +```bash +# Start server +python -m atom.entrypoints.openai_server \ + --model meta-llama/Meta-Llama-3-8B \ + --kv_cache_dtype fp8 + +# Run evaluation +lm_eval --model local-completions \ + --model_args model=meta-llama/Meta-Llama-3-8B,base_url=http://localhost:8000/v1/completions,num_concurrent=64,max_retries=3,tokenized_requests=False \ + --tasks gsm8k \ + --num_fewshot 5 +``` + +Any lm-eval task can be used. The `local-completions` model type sends +requests to the `/v1/completions` endpoint, making it compatible with the ATOM +server without modification. + +--- + +## Source Files + +| File | Description | +|------|-------------| +| `atom/entrypoints/openai_server.py` | OpenAI-compatible API server (FastAPI + Uvicorn) | +| `atom/model_engine/llm_engine.py` | `LLMEngine` programmatic API | +| `atom/sampling_params.py` | `SamplingParams` dataclass | +| `atom/model_engine/arg_utils.py` | `EngineArgs` CLI argument definitions and engine factory | +| `atom/examples/simple_inference.py` | Simple batch inference example | +| `atom/examples/profile_offline.py` | Offline profiling tool | +| `atom/benchmarks/benchmark_serving.py` | Online serving benchmark (`BenchmarkMetrics`, dataset sampling, result reporting) | +| `atom/benchmarks/backend_request_func.py` | Async HTTP request functions for each backend (`RequestFuncInput`, `RequestFuncOutput`, `ASYNC_REQUEST_FUNCS`) | +| `atom/benchmarks/benchmark_utils.py` | `convert_to_pytorch_benchmark_format` utility | +| `atom/spec_decode/eagle.py` | `EagleProposer` -- MTP draft model for DeepSeek speculative decoding | +| `atom/model_ops/rejection_sampler.py` | `RejectionSampler` with Triton greedy rejection kernel | +| `atom/config.py` | `Config`, `CompilationConfig`, `SpeculativeConfig` dataclasses | +| `atom/model_engine/model_runner.py` | `ModelRunner` with `start_profiler`/`stop_profiler` and MTP statistics | From 2c12a07b11a441c9d1bd66ef60fac6e668b511b8 Mon Sep 17 00:00:00 2001 From: Lingpeng Jin <103567126+valarLip@users.noreply.github.com> Date: Sun, 8 Feb 2026 16:28:08 +0800 Subject: [PATCH 2/8] Update README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 92139a4b2..e3d9ee2b0 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ | [DeepSeek V2/V3](https://huggingface.co/deepseek-ai) | `DeepseekV3ForCausalLM` | MoE | MLA attention, MTP speculative decoding | | [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) | `MixtralForCausalLM` | MoE | 8 experts, top-2 routing | | [GLM-4-MoE](https://huggingface.co/THUDM) | `Glm4MoeForCausalLM` | MoE | | -| [GPT-OSS](https://huggingface.co/openai) | `GptOssForCausalLM` | Dense | Sliding window + attention sinks | +| [GPT-OSS](https://huggingface.co/openai) | `GptOssForCausalLM` | MoE | Sliding window + attention sinks | | [Kimi-K2](https://huggingface.co/moonshotai/Kimi-K2-Thinking) | via `--trust-remote-code` | MoE | See [recipe](recipes/Kimi-K2-Thinking.md) | ## 📋 Requirements From f37370efe60fc74cd68e767ecc5fde1a3f5efda1 Mon Sep 17 00:00:00 2001 From: Peng Sun Date: Wed, 18 Feb 2026 18:29:41 -0600 Subject: [PATCH 3/8] Enable Triton-only prefill attention and ENABLE_CK Docker support 1. Route prefill to prefill_attention_triton when use_triton_attn=True (models with head_dim!=128 or sliding_window). Previously prefill always used CK-based flash_attn_varlen_func, which fails when AITER is built with ENABLE_CK=0. 2. Create fake_block_tables inline from cu_seqlens_k in prefill_attention_triton. The attn_metadata.fake_block_tables field was never populated, causing NoneType stride errors. 3. Dockerfile: add ENABLE_CK build arg for Triton-only AITER builds, install triton_kernels package (required for MXFP4 MoE on gfx94x), and conditionally skip CK submodule init when ENABLE_CK=0. Tested with GPT-OSS-120B (head_dim=64, MXFP4 MoE, sliding_window=128) on MI300X using ENABLE_CK=0 Docker image. --- atom/model_ops/attention_mha.py | 16 +++++++++++++++- docker/Dockerfile | 13 +++++++++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index b9e0a286e..496f79cf5 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -370,7 +370,19 @@ def prefill_attention_triton( if ctx.is_prefill: k_cache = k.unsqueeze(1) v_cache = v.unsqueeze(1) - block_tables = attn_metadata.fake_block_tables + # Create fake block_tables for prefill: each token is its own + # "block" (block_size=1). Shape [num_seqs, max_seqlen_k]. + batch_size = attn_metadata.cu_seqlens_k.shape[0] - 1 + max_len = attn_metadata.max_seqlen_k + block_tables = torch.zeros( + batch_size, max_len, dtype=torch.int32, device=q.device + ) + for i in range(batch_size): + s = attn_metadata.cu_seqlens_k[i].item() + e = attn_metadata.cu_seqlens_k[i + 1].item() + block_tables[i, : e - s] = torch.arange( + s, e, dtype=torch.int32, device=q.device + ) o = torch.empty_like(q) descale_shape = (attn_metadata.cu_seqlens_q.shape[0] - 1, k.shape[1]) @@ -407,6 +419,8 @@ def dispatch_backend(self, fwd_ctx: ForwardContext): ctx = fwd_ctx.context if ctx.is_prefill: + if self.use_triton_attn: + return self.prefill_attention_triton return self.prefill_attention else: if self.use_triton_attn: diff --git a/docker/Dockerfile b/docker/Dockerfile index 85c99daac..7d6b9d837 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -14,6 +14,8 @@ ARG AITER_COMMIT="HEAD" ARG MORI_COMMIT="b0dce4beebeb1f26c784eee17d5fd9785ee9447f" ARG PREBUILD_KERNELS=1 ARG MAX_JOBS +# Set ENABLE_CK=0 to skip CK/ASM modules for a fast Triton-only AITER build +ARG ENABLE_CK=1 RUN pip install --upgrade pip RUN pip install lm-eval[api] @@ -63,6 +65,10 @@ RUN git clone --depth=1 --branch release/internal/3.5.x https://github.com/ROCm/ MAX_JOBS=64 pip --retries=10 --default-timeout=60 install . RUN pip show triton || true +# Install triton_kernels (required for MXFP4 MoE on gfx94x) +RUN pip install --no-deps -e /triton-test/python/triton_kernels/ +RUN pip show triton-kernels || true + # Install Aiter RUN mkdir -p /app RUN pip uninstall -y aiter || true @@ -70,8 +76,11 @@ RUN git clone $AITER_REPO /app/aiter-test && \ cd /app/aiter-test && \ pip install -r requirements.txt && \ git checkout $AITER_COMMIT && \ - git submodule sync && git submodule update --init --recursive && \ - MAX_JOBS=$MAX_JOBS PREBUILD_KERNELS=$PREBUILD_KERNELS GPU_ARCHS=$GPU_ARCH_LIST python3 setup.py develop + if [ "$ENABLE_CK" != "0" ]; then \ + git submodule sync && git submodule update --init --recursive; \ + fi && \ + ENABLE_CK=$ENABLE_CK MAX_JOBS=$MAX_JOBS PREBUILD_KERNELS=$PREBUILD_KERNELS \ + GPU_ARCHS=$GPU_ARCH_LIST python3 setup.py develop RUN pip show amd-aiter || true From 3c6872d98c31e249157aefd099168f618bd0db1b Mon Sep 17 00:00:00 2001 From: Peng Sun Date: Thu, 19 Feb 2026 16:58:46 -0600 Subject: [PATCH 4/8] Add nightly upstream sync workflow --- .github/workflows/sync-upstream.yaml | 42 ++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 .github/workflows/sync-upstream.yaml diff --git a/.github/workflows/sync-upstream.yaml b/.github/workflows/sync-upstream.yaml new file mode 100644 index 000000000..2afda6a17 --- /dev/null +++ b/.github/workflows/sync-upstream.yaml @@ -0,0 +1,42 @@ +name: Sync upstream main + +on: + schedule: + # Run nightly at 06:00 UTC (midnight CST) + - cron: '0 6 * * *' + workflow_dispatch: # Allow manual trigger + +jobs: + sync: + runs-on: ubuntu-latest + steps: + - name: Checkout fork + uses: actions/checkout@v4 + with: + ref: main + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Add upstream remote + run: git remote add upstream https://github.com/ROCm/ATOM.git + + - name: Fetch upstream + run: git fetch upstream main + + - name: Check for new commits + id: check + run: | + BEHIND=$(git rev-list --count HEAD..upstream/main) + echo "behind=$BEHIND" >> "$GITHUB_OUTPUT" + echo "Fork is $BEHIND commit(s) behind upstream" + + - name: Merge upstream + if: steps.check.outputs.behind != '0' + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git merge upstream/main --no-edit + + - name: Push + if: steps.check.outputs.behind != '0' + run: git push origin main From a3168bdc3b3427b48bc0f21cb30ee8c455eb9371 Mon Sep 17 00:00:00 2001 From: Peng Date: Thu, 19 Feb 2026 20:13:41 -0600 Subject: [PATCH 5/8] Add multi-stage Docker build with MORI, FlyDSL and Triton MOE fallback * Add Dockerfile.clean for minimal ATOM build from public sources - Dockerfile.clean: clean Docker build using rocm/dev-ubuntu-22.04:7.2-complete base, PyTorch nightly ROCm 7.2 wheel, ROCm Triton 3.5.x from source (replaces incompatible Triton 3.6.0), and AITER Triton-only build (ENABLE_CK=0) - Fix scheduler.py: initialize num_rejected=0 before speculative-decode branch to prevent UnboundLocalError in non-speculative path (regression from #219) - Fix test_scheduler.py: add required num_rejected param to ScheduledBatchOutput - Add .dockerignore to exclude .git and build artifacts from Docker context * Fix shard_offset UnboundLocalError in MergedColumnParallelLinear for per_Token/per_1x32 quant types weight_loader() only handled per_1x128 and per_Tensor quant types when computing shard_offset for scale parameters. For per_Token and per_1x32 quant types (used by DeepSeek-R1 FP8), shard_offset was left undefined causing UnboundLocalError. Add else clause with same shard_offset logic as normal weights. * Add multi-stage wheel build and Triton MOE fallback - Dockerfile.wheels: builder stage that compiles/downloads all wheels (PyTorch ROCm 7.2, Triton 3.5.x, AITER with ENABLE_CK=0) - Dockerfile.clean: rewritten for zero-compilation install from pre-built wheels via bind-mount (37.9GB vs 67.9GB) - moe.py: add Triton MOE fallback for FP8 when CK sorting kernel is unavailable (CompressedTensorsFp8MoEMethod + Fp8MoEMethod), skip weight shuffle in Triton path * Add MORI and FlyDSL wheel builds to Docker multi-stage build Dockerfile.wheels: - Add LLVM/MLIR build from ROCm/llvm-project (blobless clone for speed) - Add FlyDSL wheel build from ROCm/FlyDSL source - Add MORI wheel build from ROCm/mori source - Patch Caffe2Config.cmake to work with ROCm nightly torch - Filter torch/triton from MORI requirements to preserve ROCm wheels Dockerfile.clean: - Add openmpi-bin, libopenmpi-dev, libdw1 for MORI runtime - Install mori and flydsl wheels before amd_aiter - Add FlyDSL import check; use pip show for MORI (segfaults without GPU) --- .dockerignore | 6 + atom/model_engine/scheduler.py | 1 + atom/model_ops/linear.py | 4 + atom/model_ops/moe.py | 263 +++++++++++++++++++++++++++++++-- docker/Dockerfile.clean | 69 +++++++++ docker/Dockerfile.wheels | 159 ++++++++++++++++++++ tests/test_scheduler.py | 12 +- 7 files changed, 501 insertions(+), 13 deletions(-) create mode 100644 .dockerignore create mode 100644 docker/Dockerfile.clean create mode 100644 docker/Dockerfile.wheels diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..c69aefffd --- /dev/null +++ b/.dockerignore @@ -0,0 +1,6 @@ +.git +__pycache__ +*.pyc +*.egg-info +build/ +dist/ diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index e27d3af61..5e169a1b5 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -297,6 +297,7 @@ def postprocess( continue token_ids = prev_token_ids[seq.id] num_new_token = len(token_ids) + num_rejected = 0 self.update_spec_stats(num_new_token) idx = fwd_output.req_ids.index(seq.id) if is_deferred_out or self.use_spec: diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 1bd3538a3..83b492a99 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -700,6 +700,10 @@ def weight_loader( elif self.quant_type == QuantType.per_Tensor: shard_offset = loaded_shard_id shard_size = 1 + else: + # per_Token and per_1x32: scale dim 0 matches output_size + shard_offset = sum(self.output_sizes[:loaded_shard_id]) + shard_size = self.output_sizes[loaded_shard_id] else: shard_offset = sum(self.output_sizes[:loaded_shard_id]) shard_size = self.output_sizes[loaded_shard_id] diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index d9dc1e34c..588cc47a7 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -48,6 +48,200 @@ from torch import nn from transformers import PretrainedConfig +import logging + +_moe_logger = logging.getLogger(__name__) + + +def _has_ck_moe_sorting() -> bool: + """Check if CK MOE sorting kernel is available.""" + try: + import importlib + + return importlib.util.find_spec("aiter.jit.module_moe_sorting") is not None + except Exception: + return False + + +def _per_token_group_quant_fp8(x, group_size, fp8_dtype): + """Quantize input tensor to FP8 with per-token-group scaling. + + Args: + x: Input tensor of shape (M, K) in bf16/fp16. + group_size: Number of elements per quantization group. + fp8_dtype: Target FP8 dtype (e.g. torch.float8_e4m3fnuz). + + Returns: + x_fp8: Quantized tensor of shape (M, K). + scale: Dequantization scale of shape (M, K // group_size). + """ + M, K = x.shape + assert K % group_size == 0 + num_groups = K // group_size + x_float = x.float() + x_grouped = x_float.view(M, num_groups, group_size) + max_abs = x_grouped.abs().amax(dim=-1) # (M, num_groups) + fp8_max = torch.finfo(fp8_dtype).max + scale = (max_abs / fp8_max).clamp(min=1e-12) + x_scaled = x_grouped / scale.unsqueeze(-1) + x_fp8 = x_scaled.clamp(-fp8_max, fp8_max).to(fp8_dtype) + x_fp8 = x_fp8.view(M, K) + return x_fp8, scale + + +def _triton_fp8_moe( + x, + w13, + w2, + topk_weights, + topk_ids, + w13_scale, + w2_scale, + top_k, + block_quant, + quant_type, +): + """Execute FP8 MOE using AITER Triton kernels (no CK dependency). + + Two-stage pipeline: + Stage 1 (GEMM1+SiLU): x @ w13^T with SiLU gating + Stage 2 (GEMM2): intermediate @ w2^T with routing weight accumulation + + For GEMM2, we reshape the intermediate so each (token, expert_k) pair is + treated as an independent token with top_k=1, allowing correct A indexing. + """ + import triton.language as tl + from aiter.ops.triton.moe.moe_align_block_size import moe_align_block_size_triton + from aiter.ops.triton.moe.moe_op_silu_fused import fused_moe_silu + from aiter.ops.triton.moe.moe_op import fused_moe as triton_fused_moe + from aiter.ops.triton.utils.moe_config_utils import get_optimal_moe_config + + M, hidden_dim = x.shape + E = w13.shape[0] + inter_dim_2 = w13.shape[1] # 2 * inter_dim + inter_dim = inter_dim_2 // 2 + + if block_quant: + if quant_type == QuantType.per_1x128: + block_shape = [128, 128] + elif quant_type == QuantType.per_1x32: + block_shape = [1, 32] + else: + block_shape = None + else: + block_shape = None + + config = get_optimal_moe_config(dtype=x.dtype, use_fp8_w8a8=True, M=M) + block_size_m = config["BLOCK_SIZE_M"] + compute_type = tl.bfloat16 if x.dtype == torch.bfloat16 else tl.float16 + + # --- Stage 1: Sorting --- + max_num_tokens_padded = topk_ids.numel() + E * (block_size_m - 1) + sorted_token_ids = torch.empty( + max_num_tokens_padded, dtype=torch.int32, device=x.device + ) + sorted_token_ids.fill_(topk_ids.numel()) + max_num_m_blocks = (max_num_tokens_padded + block_size_m - 1) // block_size_m + expert_ids = torch.empty(max_num_m_blocks, dtype=torch.int32, device=x.device) + num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=x.device) + + moe_align_block_size_triton( + topk_ids, E, block_size_m, sorted_token_ids, expert_ids, num_tokens_post_pad + ) + + # --- Stage 2: GEMM1 with SiLU (x @ w13^T) --- + if block_quant and block_shape is not None: + block_k = block_shape[1] + a_fp8, a_scale = _per_token_group_quant_fp8(x, block_k, w13.dtype) + else: + a_fp8 = x + a_scale = None + + intermediate = torch.zeros(M * top_k, inter_dim, dtype=x.dtype, device=x.device) + + fused_moe_silu( + A=a_fp8, + B=w13, + C=intermediate, + A_scale=a_scale, + B_scale=w13_scale, + B_zp=None, + topk_weights=topk_weights, + topk_ids=topk_ids, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_pad, + mul_routed_weight=False, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=True, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=block_shape, + config=config, + ) + + # --- Stage 3: GEMM2 (intermediate @ w2^T) --- + # Reshape for GEMM2: treat each (token, expert_k) as independent token + # with top_k=1 so the kernel indexes A correctly (A // top_k = A // 1 = A) + gemm2_topk_ids = topk_ids.reshape(M * top_k, 1) + gemm2_topk_weights = topk_weights.reshape(M * top_k, 1) + + # Re-sort for GEMM2 with the reshaped topk_ids + gemm2_max_padded = gemm2_topk_ids.numel() + E * (block_size_m - 1) + gemm2_sorted_ids = torch.empty(gemm2_max_padded, dtype=torch.int32, device=x.device) + gemm2_sorted_ids.fill_(gemm2_topk_ids.numel()) + gemm2_max_blocks = (gemm2_max_padded + block_size_m - 1) // block_size_m + gemm2_expert_ids = torch.empty(gemm2_max_blocks, dtype=torch.int32, device=x.device) + gemm2_num_pad = torch.empty(1, dtype=torch.int32, device=x.device) + + moe_align_block_size_triton( + gemm2_topk_ids, + E, + block_size_m, + gemm2_sorted_ids, + gemm2_expert_ids, + gemm2_num_pad, + ) + + # Quantize intermediate for FP8 GEMM2 + if block_quant and block_shape is not None: + block_k2 = block_shape[1] + inter_fp8, inter_scale = _per_token_group_quant_fp8( + intermediate, block_k2, w2.dtype + ) + else: + inter_fp8 = intermediate + inter_scale = None + + output = torch.zeros(M * top_k, 1, hidden_dim, dtype=x.dtype, device=x.device) + + triton_fused_moe( + A=inter_fp8, + B=w2, + C=output, + A_scale=inter_scale, + B_scale=w2_scale, + B_zp=None, + topk_weights=gemm2_topk_weights, + topk_ids=gemm2_topk_ids, + sorted_token_ids=gemm2_sorted_ids, + expert_ids=gemm2_expert_ids, + num_tokens_post_padded=gemm2_num_pad, + mul_routed_weight=True, + top_k=1, + compute_type=compute_type, + use_fp8_w8a8=True, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=block_shape, + config=config, + ) + + # Reduce: sum across top_k experts per token + result = output.squeeze(1).view(M, top_k, hidden_dim).sum(dim=1) + return result + class FusedMoeWeightScaleSupported(Enum): """Supported quantization strategies for MoE weight scales.""" @@ -980,6 +1174,14 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): self.block_n = 1 self.block_k = 32 + # Detect CK MOE availability; fall back to Triton MOE if unavailable + self.use_triton_moe = not _has_ck_moe_sorting() + if self.use_triton_moe: + _moe_logger.info( + "CK MOE sorting not available, using Triton MOE kernels " + "for CompressedTensors FP8" + ) + def create_weights( self, layer: torch.nn.Module, @@ -1220,16 +1422,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) # Shuffle weights for asm moe (moved from inference to load time for better performance) - if w13.dtype in [ - torch.int8, - torch.uint8, - torch.float8_e4m3fnuz, - torch.float8_e4m3fn, - ]: - from aiter.ops.shuffle import shuffle_weight + # Skip shuffle when using Triton path (Triton expects standard row-major) + if not self.use_triton_moe: + if w13.dtype in [ + torch.int8, + torch.uint8, + torch.float8_e4m3fnuz, + torch.float8_e4m3fn, + ]: + from aiter.ops.shuffle import shuffle_weight - w13.data = shuffle_weight(w13.data) - w2.data = shuffle_weight(w2.data) + w13.data = shuffle_weight(w13.data) + w2.data = shuffle_weight(w2.data) # Call parent class for any additional processing super().process_weights_after_loading(layer) @@ -1298,6 +1502,21 @@ def apply( a1_scale = getattr(layer, "w13_input_scale", None) a2_scale = getattr(layer, "w2_input_scale", None) + # Triton MOE fallback when CK is not available + if self.use_triton_moe: + return _triton_fp8_moe( + x=x, + w13=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + top_k=top_k, + block_quant=self.block_quant, + quant_type=self.quant_type, + ) + # Use modular kernel if available (for EP/DP setups) # Otherwise fall back to direct kernel call if self.fused_experts is not None: @@ -1362,6 +1581,12 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): self.need_normalize_e4m3fn_to_e4m3fnuz = ( self.quant_dtype == torch.float8_e4m3fnuz ) + # Detect CK MOE availability; fall back to Triton MOE if unavailable + self.use_triton_moe = not _has_ck_moe_sorting() + if self.use_triton_moe: + _moe_logger.info( + "CK MOE sorting not available, using Triton MOE kernels for FP8" + ) def create_weights( self, @@ -1525,7 +1750,8 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: layer.w2_weight = nn.Parameter(w2_weight, requires_grad=False) layer.w2_weight_scale = nn.Parameter(w2_weight_scale, requires_grad=False) - shuffle_weights(layer.w13_weight, layer.w2_weight) + if not self.use_triton_moe: + shuffle_weights(layer.w13_weight, layer.w2_weight) return else: @@ -1597,7 +1823,8 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: ) start += shard_size - shuffle_weights(layer.w13_weight, layer.w2_weight) + if not self.use_triton_moe: + shuffle_weights(layer.w13_weight, layer.w2_weight) layer.w13_weight_scale = torch.nn.Parameter( max_w13_scales, requires_grad=False @@ -1647,6 +1874,20 @@ def apply( num_fused_shared_experts=layer.num_fused_shared_experts, routed_scaling_factor=layer.routed_scaling_factor, ) + # Triton MOE fallback when CK is not available + if self.use_triton_moe: + return _triton_fp8_moe( + x=x, + w13=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + top_k=top_k, + block_quant=self.block_quant, + quant_type=self.quant_type, + ) # per_Tensor not support num_local_tokens so not use mori if self.quant_type == QuantType.per_Tensor or self.fused_experts is None: return torch.ops.aiter.rocm_aiter_fused_moe( diff --git a/docker/Dockerfile.clean b/docker/Dockerfile.clean new file mode 100644 index 000000000..4a2d4dbc6 --- /dev/null +++ b/docker/Dockerfile.clean @@ -0,0 +1,69 @@ +# Dockerfile.clean — Wheel-only ATOM/AITER build (zero source compilation) +# +# Base: rocm/dev-ubuntu-24.04:7.2-complete (Python 3.12, full ROCm runtime) +# All packages installed from pre-built wheels — no git clones, no compiles. +# +# Option A — from pre-built wheels directory: +# cd /home/pensun/ATOM +# DOCKER_BUILDKIT=1 docker build \ +# --build-context wheels=/home/pensun/dist \ +# -f docker/Dockerfile.clean -t atom:clean . +# +# Option B — multi-stage from Dockerfile.wheels builder image: +# docker build -f docker/Dockerfile.wheels -t atom:wheels . +# DOCKER_BUILDKIT=1 docker build \ +# --build-context wheels=docker-image://atom:wheels \ +# -f docker/Dockerfile.clean -t atom:clean . +# +# Run: +# docker run --rm --device=/dev/kfd --device=/dev/dri \ +# -v /data2/models:/models atom:clean bash + +ARG BASE_IMAGE="rocm/dev-ubuntu-24.04:7.2-complete" +FROM ${BASE_IMAGE} + +ENV DEBIAN_FRONTEND=noninteractive + +# ── 1. System packages (minimal — no build tools needed) ───────────── +RUN apt-get update && apt-get install -y --no-install-recommends \ + git python3-pip python3-dev \ + ibverbs-utils libpci-dev locales \ + openmpi-bin libopenmpi-dev libdw1 \ + && rm -rf /var/lib/apt/lists/* + +RUN pip3 install --break-system-packages --ignore-installed pip setuptools wheel + +# ── 2. Install all pre-built wheels ──────────────────────────────────── +# Uses bind-mount to avoid a 60+ GB COPY layer from the wheels image. +# Works with both Option A (flat directory) and Option B (docker-image://). +RUN --mount=type=bind,from=wheels,source=/,target=/mnt/wheels \ + mkdir -p /tmp/wheels \ + && find /mnt/wheels -name '*.whl' -exec cp {} /tmp/wheels/ \; \ + && ls -lhS /tmp/wheels/*.whl \ + && pip3 install --break-system-packages --no-deps \ + /tmp/wheels/torch-*.whl \ + /tmp/wheels/torchvision-*.whl \ + /tmp/wheels/torchaudio-*.whl \ + /tmp/wheels/triton-*.whl \ + /tmp/wheels/triton_kernels-*.whl \ + && pip3 install --break-system-packages \ + filelock typing-extensions sympy networkx jinja2 fsspec numpy pillow \ + && pip3 install --break-system-packages \ + /tmp/wheels/mori-*.whl \ + /tmp/wheels/flydsl-*.whl \ + && pip3 install --break-system-packages \ + /tmp/wheels/amd_aiter-*.whl \ + && rm -rf /tmp/wheels \ + && python3 -c "import torch; print(f'PyTorch {torch.__version__}, ROCm: {torch.version.hip}')" \ + && python3 -c "import triton; print(f'Triton {triton.__version__}')" \ + && python3 -c "import aiter; print('AITER OK')" \ + && python3 -c "import flydsl; print('FlyDSL OK')" \ + && pip3 show mori && echo "MORI wheel installed OK" + +# ── 3. ATOM (from build context — pure Python, instant install) ────── +COPY . /app/ATOM +RUN cd /app/ATOM && pip3 install --break-system-packages -e . \ + && python3 -c "import atom; print('ATOM OK')" + +WORKDIR /app/ATOM +CMD ["/bin/bash"] diff --git a/docker/Dockerfile.wheels b/docker/Dockerfile.wheels new file mode 100644 index 000000000..e9da1a648 --- /dev/null +++ b/docker/Dockerfile.wheels @@ -0,0 +1,159 @@ +# Dockerfile.wheels — Build/download all wheels for ATOM clean install +# +# Produces /wheels/ containing: +# torch, torchvision, torchaudio (pulled from PyTorch nightly) +# triton 3.5.x (built from ROCm/triton source) +# triton_kernels (built from ROCm/triton source) +# flydsl (built from FlyDSL source + embedded MLIR runtime) +# mori (built from MORI source) +# amd_aiter (built with ENABLE_CK=0 + pre-compiled Triton kernels) +# +# Usage (standalone — extract wheels to host): +# docker build -f docker/Dockerfile.wheels -t atom:wheels . +# docker run --rm atom:wheels tar cf - /wheels | tar xf - -C /home/pensun/dist --strip-components=1 +# +# Usage (multi-stage — pipe directly into Dockerfile.clean): +# DOCKER_BUILDKIT=1 docker build \ +# --build-context wheels=docker-image://atom:wheels \ +# -f docker/Dockerfile.clean -t atom:clean . + +ARG BASE_IMAGE="rocm/dev-ubuntu-24.04:7.2-complete" +FROM ${BASE_IMAGE} + +ARG GPU_ARCH="gfx942;gfx950" +ARG AITER_REPO="https://github.com/sunway513/aiter.git" +ARG AITER_BRANCH="feat/prebuild-triton" +ARG FLYDSL_REPO="https://github.com/ROCm/FlyDSL.git" +ARG FLYDSL_BRANCH="main" +ARG LLVM_COMMIT="04f968b02917" +ARG MORI_REPO="https://github.com/ROCm/mori.git" +ARG MORI_COMMIT="b0dce4beebeb1f26c784eee17d5fd9785ee9447f" +ARG MAX_JOBS="" +ARG PREBUILD_TRITON=1 + +ENV GPU_ARCH_LIST=${GPU_ARCH} +ENV PYTORCH_ROCM_ARCH=${GPU_ARCH} +ENV DEBIAN_FRONTEND=noninteractive + +# ── 1. System packages + build tools ──────────────────────────────── +RUN apt-get update && apt-get install -y --no-install-recommends \ + git cmake ninja-build \ + python3-pip python3-dev python3-venv \ + ibverbs-utils libpci-dev locales \ + && rm -rf /var/lib/apt/lists/* + +RUN pip3 install --break-system-packages --ignore-installed \ + pip setuptools wheel build + +RUN mkdir -p /wheels + +# ── 2. Pull PyTorch ROCm 7.2 nightly wheels ───────────────────────── +RUN pip3 download --no-deps --dest /wheels \ + torch torchvision torchaudio \ + --index-url https://download.pytorch.org/whl/nightly/rocm7.2 + +# ── 3. Build Triton 3.5.x from ROCm fork ──────────────────────────── +RUN git clone --depth=1 --branch release/internal/3.5.x \ + https://github.com/ROCm/triton.git /build/triton + +RUN cd /build/triton \ + && pip3 install --break-system-packages -r python/requirements.txt \ + && pip3 install --break-system-packages filecheck \ + && MAX_JOBS=${MAX_JOBS:-64} pip3 wheel \ + --no-build-isolation --no-deps -w /wheels . \ + && ls -lh /wheels/triton-*.whl + +# Build triton_kernels wheel +RUN cd /build/triton/python/triton_kernels \ + && pip3 wheel --no-deps -w /wheels . \ + && ls -lh /wheels/triton_kernels-*.whl + +# ── 4. Build LLVM/MLIR for FlyDSL ─────────────────────────────────── +# Blobless clone (~6 min vs ~30 min full clone). LLVM_COMMIT rarely +# changes, so this layer stays cached across most rebuilds. +RUN pip3 install --break-system-packages nanobind numpy pybind11 + +RUN git clone --filter=blob:none --no-checkout \ + https://github.com/ROCm/llvm-project.git /build/llvm-project \ + && cd /build/llvm-project \ + && git fetch origin amd-staging \ + && git checkout ${LLVM_COMMIT} + +RUN mkdir -p /build/llvm-project/buildmlir \ + && cd /build/llvm-project/buildmlir \ + && NANOBIND_DIR=$(python3 -c "import nanobind; import os; print(os.path.dirname(nanobind.__file__) + '/cmake')") \ + && cmake -G Ninja \ + -S /build/llvm-project/llvm \ + -DLLVM_ENABLE_PROJECTS="mlir;clang" \ + -DLLVM_TARGETS_TO_BUILD="X86;NVPTX;AMDGPU" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CXX_STANDARD=17 \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_INSTALL_UTILS=ON \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE=$(which python3) \ + -Dnanobind_DIR="$NANOBIND_DIR" \ + -DBUILD_SHARED_LIBS=OFF \ + && cmake --build . -j$(nproc) \ + && cmake --install . --prefix /build/llvm-project/mlir_install + +# ── 5. Install torch + triton (needed for AITER/MORI builds) ──────── +RUN pip3 install --break-system-packages --no-deps \ + /wheels/torch-*.whl /wheels/triton-3.5*.whl \ + && pip3 install --break-system-packages \ + filelock typing-extensions sympy networkx jinja2 fsspec numpy + +# ── 6. Build FlyDSL wheel ─────────────────────────────────────────── +RUN git clone --depth=1 --branch ${FLYDSL_BRANCH} ${FLYDSL_REPO} /build/FlyDSL + +RUN cd /build/FlyDSL \ + && export MLIR_PATH=/build/llvm-project/mlir_install \ + && bash flir/build.sh \ + && export FLIR_IN_BUILD_SH=1 \ + && pip3 install --break-system-packages auditwheel patchelf \ + && python3 setup.py bdist_wheel \ + && cp dist/flydsl-*.whl /wheels/ \ + && ls -lh /wheels/flydsl-*.whl + +# ── 7. Build MORI wheel ───────────────────────────────────────────── +RUN apt-get update && apt-get install -y --no-install-recommends \ + openmpi-bin libopenmpi-dev cython3 libdw1 \ + && rm -rf /var/lib/apt/lists/* + +# Patch PyTorch's Caffe2Config.cmake: the ROCm nightly wheel's config +# hard-errors when CUDA toolkit is not found, even though we only need ROCm. +# Convert the fatal error to a warning so MORI (and other torch-cmake users) +# can build against the ROCm PyTorch wheel without CUDA installed. +RUN CAFFE2_CFG=$(python3 -c "import torch, pathlib; print(pathlib.Path(torch.__file__).parent / 'share/cmake/Caffe2/Caffe2Config.cmake')") \ + && sed -i 's/message(FATAL_ERROR "Your installed Caffe2 version uses CUDA/message(WARNING "Skipped: Your installed Caffe2 version uses CUDA/' "$CAFFE2_CFG" + +RUN git clone ${MORI_REPO} /build/mori \ + && cd /build/mori \ + && git checkout ${MORI_COMMIT} \ + && grep -iv '^torch\|^triton' requirements-build.txt \ + | pip3 install --break-system-packages -r /dev/stdin \ + && git submodule update --init --recursive \ + && pip3 wheel --no-build-isolation --no-deps -w /wheels . \ + && ls -lh /wheels/mori-*.whl + +# ── 8. Build AITER wheel (ENABLE_CK=0, pre-compiled Triton kernels) ── +RUN git clone --depth=1 --branch ${AITER_BRANCH} ${AITER_REPO} /build/aiter + +# Set AITER build env for all subsequent commands in this layer +RUN cd /build/aiter \ + && pip3 install --break-system-packages -r requirements.txt \ + && export ENABLE_CK=0 PREBUILD_TRITON=${PREBUILD_TRITON} \ + PREBUILD_TRITON_ARCHS="gfx942,gfx950" \ + MAX_JOBS=${MAX_JOBS} GPU_ARCHS=${GPU_ARCH_LIST} \ + && pip3 install --break-system-packages --no-build-isolation -e . \ + && python3 -c "import aiter; print('editable install OK')" \ + && echo "install" > aiter/install_mode \ + && python3 setup.py bdist_wheel \ + && cp dist/amd_aiter-*.whl /wheels/ \ + && ls -lh /wheels/amd_aiter-*.whl + +# ── 9. Summary ────────────────────────────────────────────────────── +RUN echo "=== Wheel inventory ===" && ls -lhS /wheels/*.whl && echo "=== Done ===" + +WORKDIR /wheels +CMD ["ls", "-lhS", "/wheels/"] diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 324c10a9c..48155c046 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: MIT # Tests for atom/model_engine/scheduler.py — public API only +import numpy as np + from atom.model_engine.scheduler import Scheduler, ScheduledBatchOutput from atom.model_engine.sequence import SequenceStatus, SequenceType from atom.sampling_params import SamplingParams @@ -121,7 +123,9 @@ def _prefill(self, scheduler, seq): def _output(self, seq_id, tokens): return ScheduledBatchOutput( - token_ids={seq_id: tuple(tokens)}, draft_token_ids=None + token_ids={seq_id: tuple(tokens)}, + num_rejected=np.zeros(0, dtype=np.int32), + draft_token_ids=None, ) def test_appends_token(self, scheduler, seq_factory): @@ -166,7 +170,11 @@ def test_stop_token_ids(self, seq_factory): sched.schedule() finished = sched.postprocess( list(sched.running), - ScheduledBatchOutput(token_ids={seq.id: (99,)}, draft_token_ids=None), + ScheduledBatchOutput( + token_ids={seq.id: (99,)}, + num_rejected=np.zeros(0, dtype=np.int32), + draft_token_ids=None, + ), ) assert len(finished) == 1 assert "stop_99" in finished[0].leave_reason From 9c6a7b83ac4bfff08865a5fe1f47a38482adce97 Mon Sep 17 00:00:00 2001 From: Peng Date: Thu, 19 Feb 2026 22:53:47 -0600 Subject: [PATCH 6/8] Fix DeepSeek R1 for Triton-only build (ENABLE_CK=0) (#10) - MOE: handle fused shared expert top_k mismatch (actual_top_k = topk_ids.numel() // M) - MLA prefill: replace flash_attn_varlen_func with PyTorch SDPA (no CK dependency) - MHA prefill: always dispatch to prefill_attention_triton (layer 0 uses MHA) - aiter_mla: populate paged KV metadata (kv_indptr, kv_indices) in prepare_prefill --- atom/model_ops/attention_mha.py | 5 ++- atom/model_ops/attention_mla.py | 29 ++++++++------- atom/model_ops/attentions/aiter_mla.py | 50 ++++++++++++++++++++++++++ atom/model_ops/moe.py | 18 ++++++---- 4 files changed, 80 insertions(+), 22 deletions(-) diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index 496f79cf5..74622651b 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -419,9 +419,8 @@ def dispatch_backend(self, fwd_ctx: ForwardContext): ctx = fwd_ctx.context if ctx.is_prefill: - if self.use_triton_attn: - return self.prefill_attention_triton - return self.prefill_attention + # Always use Triton prefill (no CK/flash_attn_varlen_func dependency) + return self.prefill_attention_triton else: if self.use_triton_attn: return self.paged_attention_triton diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 6b2452cd1..954ca8220 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -396,19 +396,22 @@ def _forward_prefill_mha( k = torch.cat((k_nope, k_rope.expand((*k_nope.shape[:-1], -1))), dim=-1) - output = flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=attn_metadata.cu_seqlens_q, - cu_seqlens_k=attn_metadata.cu_seqlens_k, - max_seqlen_q=attn_metadata.max_seqlen_q, - max_seqlen_k=attn_metadata.max_seqlen_k, - min_seqlen_q=attn_metadata.min_seqlen_q, - dropout_p=attn_metadata.dropout_p, - softmax_scale=self.scale, - causal=True, - ) + # Use PyTorch SDPA for MLA prefill attention (no CK dependency) + import torch.nn.functional as F + + cu_q = attn_metadata.cu_seqlens_q + cu_k = attn_metadata.cu_seqlens_k + num_seqs = cu_q.shape[0] - 1 + outputs = [] + for i in range(num_seqs): + qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0) + ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0) + vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0) + oi = F.scaled_dot_product_attention( + qi, ki, vi, is_causal=True, scale=self.scale + ) + outputs.append(oi.squeeze(0).transpose(0, 1)) + output = torch.cat(outputs, dim=0) return self.o_proj(output.flatten(start_dim=-2)) diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 7fd33253b..89bbdf65e 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -180,6 +180,56 @@ def prepare_prefill(self, batch: ScheduledBatch): bs = batch.total_seqs_num_prefill sum_scheduled_tokens = batch.total_tokens_num_prefill var = self.model_runner.forward_vars + + # Prepare paged KV metadata for MLA prefill paths + # (needed by mla_prefill_fwd for bf16, unified_attention for fp8) + if batch.block_tables: + context_lens = np.asarray(batch.context_lens[:bs], dtype=np.int32) + num_blocks_per_seq = cdiv(context_lens, self.block_size) + kv_indptr = np.cumsum(num_blocks_per_seq) + sum_blocks = kv_indptr[-1] + + dst = var["kv_indices"].np + offset = 0 + for i in range(bs): + bt = batch.block_tables[i] + n = len(bt) + dst[offset : offset + n] = bt + offset += n + sum_blocks_before_converted = offset + + var["kv_indptr"].np[0] = 0 + var["kv_indptr"].np[1 : bs + 1] = kv_indptr + + attn_metadata.kv_indptr = var["kv_indptr"].copy_to_gpu(bs + 1) + attn_metadata.kv_indices = var["kv_indices"].copy_to_gpu( + sum_blocks_before_converted + ) + attn_metadata.kv_last_page_lens = var["kv_last_page_lens"].gpu[:bs] + + if self.block_ratio > 1: + kv_indices_convert_triton( + var["kv_indices"].gpu[:sum_blocks_before_converted], + var["kv_indices_converted"].gpu[:sum_blocks], + var["kv_indptr"].gpu[: bs + 1], + self.block_ratio, + self.block_size, + ) + attn_metadata.kv_indices = var["kv_indices_converted"].gpu[:sum_blocks] + + # Prepare block_tables for unified_attention (fp8 prefill) + if attn_metadata.block_tables is None: + self.prepare_block_tables(batch) + attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs) + if self.block_ratio > 1: + block_table_convert_triton( + var["block_tables"].gpu[:bs], + var["block_tables_converted"].gpu[:bs], + var["context_lens"].gpu[:bs], + self.block_ratio, + ) + attn_metadata.block_tables = var["block_tables_converted"].gpu[:bs] + if self.is_sparse and attn_metadata.max_seqlen_k > self.index_topk: if attn_metadata.block_tables is None: self.prepare_block_tables(batch) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 588cc47a7..125d0d206 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -120,6 +120,8 @@ def _triton_fp8_moe( E = w13.shape[0] inter_dim_2 = w13.shape[1] # 2 * inter_dim inter_dim = inter_dim_2 // 2 + # When fused shared experts are enabled, topk_ids has M*(top_k+1) elements + actual_top_k = topk_ids.numel() // M if block_quant: if quant_type == QuantType.per_1x128: @@ -157,7 +159,9 @@ def _triton_fp8_moe( a_fp8 = x a_scale = None - intermediate = torch.zeros(M * top_k, inter_dim, dtype=x.dtype, device=x.device) + intermediate = torch.zeros( + M * actual_top_k, inter_dim, dtype=x.dtype, device=x.device + ) fused_moe_silu( A=a_fp8, @@ -172,7 +176,7 @@ def _triton_fp8_moe( expert_ids=expert_ids, num_tokens_post_padded=num_tokens_post_pad, mul_routed_weight=False, - top_k=top_k, + top_k=actual_top_k, compute_type=compute_type, use_fp8_w8a8=True, use_int8_w8a16=False, @@ -184,8 +188,8 @@ def _triton_fp8_moe( # --- Stage 3: GEMM2 (intermediate @ w2^T) --- # Reshape for GEMM2: treat each (token, expert_k) as independent token # with top_k=1 so the kernel indexes A correctly (A // top_k = A // 1 = A) - gemm2_topk_ids = topk_ids.reshape(M * top_k, 1) - gemm2_topk_weights = topk_weights.reshape(M * top_k, 1) + gemm2_topk_ids = topk_ids.reshape(M * actual_top_k, 1) + gemm2_topk_weights = topk_weights.reshape(M * actual_top_k, 1) # Re-sort for GEMM2 with the reshaped topk_ids gemm2_max_padded = gemm2_topk_ids.numel() + E * (block_size_m - 1) @@ -214,7 +218,9 @@ def _triton_fp8_moe( inter_fp8 = intermediate inter_scale = None - output = torch.zeros(M * top_k, 1, hidden_dim, dtype=x.dtype, device=x.device) + output = torch.zeros( + M * actual_top_k, 1, hidden_dim, dtype=x.dtype, device=x.device + ) triton_fused_moe( A=inter_fp8, @@ -239,7 +245,7 @@ def _triton_fp8_moe( ) # Reduce: sum across top_k experts per token - result = output.squeeze(1).view(M, top_k, hidden_dim).sum(dim=1) + result = output.squeeze(1).view(M, actual_top_k, hidden_dim).sum(dim=1) return result From 95a551704540a390243bf18b6d1609eb15082db2 Mon Sep 17 00:00:00 2001 From: Peng Sun Date: Thu, 19 Feb 2026 23:38:11 -0600 Subject: [PATCH 7/8] Enable Triton MOE path for MXFP4 on gfx950 (MI355X) The Triton MOE path (triton_kernels matmul_ogs + routing) was gated by gfx94* prefix check, excluding gfx950. Enable it on both gfx942 and gfx950 with graceful triton_kernels availability check. Also add CK-unavailable fallback: when CK MOE sorting is missing (e.g. ENABLE_CK=0 builds), automatically fall back to Triton if available. --- atom/model_ops/moe.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 125d0d206..ccceb3e72 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -834,11 +834,16 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): self.quant_type == QuantType.per_1x128 or self.quant_type == QuantType.per_1x32 ) - self.use_triton = get_gfx().startswith("gfx94") - if self.use_triton: - from atom.model_ops.utils import has_triton_kernels - - assert has_triton_kernels(), "triton_kernels is not installed" + from atom.model_ops.utils import has_triton_kernels + + self.use_triton = get_gfx() in ("gfx942", "gfx950") and has_triton_kernels() + if not self.use_triton and not _has_ck_moe_sorting(): + if has_triton_kernels(): + self.use_triton = True + _moe_logger.info( + "CK MOE sorting not available, " + "using Triton MOE kernels for MXFP4" + ) def create_weights( self, From 0b027c62d2e6cc99c9d3ea41ad26221b7c4fd3b9 Mon Sep 17 00:00:00 2001 From: Peng Sun Date: Fri, 20 Feb 2026 10:05:23 -0600 Subject: [PATCH 8/8] Fix triton_kernels compatibility for MI355X (gfx950) - Rename GFX950MXScaleLayout to CDNA4MXScaleLayout to match upstream triton_kernels (triton-lang/triton release/3.5.x) - Add block_m=128 constraint for gfx950 to avoid LDS overflow (default CDNA4 block_m=256 needs 162KB, MI355X limit is 160KB) --- atom/model_ops/fused_moe_triton.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/atom/model_ops/fused_moe_triton.py b/atom/model_ops/fused_moe_triton.py index 01d83e6e3..6324dae29 100644 --- a/atom/model_ops/fused_moe_triton.py +++ b/atom/model_ops/fused_moe_triton.py @@ -34,6 +34,11 @@ from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs from triton_kernels.routing import routing from triton_kernels.matmul_ogs import PrecisionConfig + from triton_kernels.matmul_ogs import update_opt_flags_constraints + + if get_gfx() == "gfx950": + # MI355X has 160KB LDS; default CDNA4 block_m=256 exceeds it. + update_opt_flags_constraints({"block_m": 128}) except (AttributeError, ImportError) as e: logger.error( "Failed to import Triton kernels. Please make sure your triton " @@ -53,9 +58,9 @@ def _swizzle_mxfp4(quant_tensor, scale): scale_layout_opts: dict[str, Any] = {} value_layout = StridedLayout if get_gfx() == "gfx950": - from triton_kernels.tensor_details.layout import GFX950MXScaleLayout + from triton_kernels.tensor_details.layout import CDNA4MXScaleLayout - scale_layout = GFX950MXScaleLayout + scale_layout = CDNA4MXScaleLayout else: scale_layout = StridedLayout