diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..c69aefffd --- /dev/null +++ b/.dockerignore @@ -0,0 +1,6 @@ +.git +__pycache__ +*.pyc +*.egg-info +build/ +dist/ diff --git a/.github/workflows/sync-upstream.yaml b/.github/workflows/sync-upstream.yaml new file mode 100644 index 000000000..2afda6a17 --- /dev/null +++ b/.github/workflows/sync-upstream.yaml @@ -0,0 +1,42 @@ +name: Sync upstream main + +on: + schedule: + # Run nightly at 06:00 UTC (midnight CST) + - cron: '0 6 * * *' + workflow_dispatch: # Allow manual trigger + +jobs: + sync: + runs-on: ubuntu-latest + steps: + - name: Checkout fork + uses: actions/checkout@v4 + with: + ref: main + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Add upstream remote + run: git remote add upstream https://github.com/ROCm/ATOM.git + + - name: Fetch upstream + run: git fetch upstream main + + - name: Check for new commits + id: check + run: | + BEHIND=$(git rev-list --count HEAD..upstream/main) + echo "behind=$BEHIND" >> "$GITHUB_OUTPUT" + echo "Fork is $BEHIND commit(s) behind upstream" + + - name: Merge upstream + if: steps.check.outputs.behind != '0' + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git merge upstream/main --no-edit + + - name: Push + if: steps.check.outputs.behind != '0' + run: git push origin main diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index e27d3af61..5e169a1b5 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -297,6 +297,7 @@ def postprocess( continue token_ids = prev_token_ids[seq.id] num_new_token = len(token_ids) + num_rejected = 0 self.update_spec_stats(num_new_token) idx = fwd_output.req_ids.index(seq.id) if is_deferred_out or self.use_spec: diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index b9e0a286e..74622651b 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -370,7 +370,19 @@ def prefill_attention_triton( if ctx.is_prefill: k_cache = k.unsqueeze(1) v_cache = v.unsqueeze(1) - block_tables = attn_metadata.fake_block_tables + # Create fake block_tables for prefill: each token is its own + # "block" (block_size=1). Shape [num_seqs, max_seqlen_k]. + batch_size = attn_metadata.cu_seqlens_k.shape[0] - 1 + max_len = attn_metadata.max_seqlen_k + block_tables = torch.zeros( + batch_size, max_len, dtype=torch.int32, device=q.device + ) + for i in range(batch_size): + s = attn_metadata.cu_seqlens_k[i].item() + e = attn_metadata.cu_seqlens_k[i + 1].item() + block_tables[i, : e - s] = torch.arange( + s, e, dtype=torch.int32, device=q.device + ) o = torch.empty_like(q) descale_shape = (attn_metadata.cu_seqlens_q.shape[0] - 1, k.shape[1]) @@ -407,7 +419,8 @@ def dispatch_backend(self, fwd_ctx: ForwardContext): ctx = fwd_ctx.context if ctx.is_prefill: - return self.prefill_attention + # Always use Triton prefill (no CK/flash_attn_varlen_func dependency) + return self.prefill_attention_triton else: if self.use_triton_attn: return self.paged_attention_triton diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 6b2452cd1..954ca8220 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -396,19 +396,22 @@ def _forward_prefill_mha( k = torch.cat((k_nope, k_rope.expand((*k_nope.shape[:-1], -1))), dim=-1) - output = flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=attn_metadata.cu_seqlens_q, - cu_seqlens_k=attn_metadata.cu_seqlens_k, - max_seqlen_q=attn_metadata.max_seqlen_q, - max_seqlen_k=attn_metadata.max_seqlen_k, - min_seqlen_q=attn_metadata.min_seqlen_q, - dropout_p=attn_metadata.dropout_p, - softmax_scale=self.scale, - causal=True, - ) + # Use PyTorch SDPA for MLA prefill attention (no CK dependency) + import torch.nn.functional as F + + cu_q = attn_metadata.cu_seqlens_q + cu_k = attn_metadata.cu_seqlens_k + num_seqs = cu_q.shape[0] - 1 + outputs = [] + for i in range(num_seqs): + qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0) + ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0) + vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0) + oi = F.scaled_dot_product_attention( + qi, ki, vi, is_causal=True, scale=self.scale + ) + outputs.append(oi.squeeze(0).transpose(0, 1)) + output = torch.cat(outputs, dim=0) return self.o_proj(output.flatten(start_dim=-2)) diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 7fd33253b..89bbdf65e 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -180,6 +180,56 @@ def prepare_prefill(self, batch: ScheduledBatch): bs = batch.total_seqs_num_prefill sum_scheduled_tokens = batch.total_tokens_num_prefill var = self.model_runner.forward_vars + + # Prepare paged KV metadata for MLA prefill paths + # (needed by mla_prefill_fwd for bf16, unified_attention for fp8) + if batch.block_tables: + context_lens = np.asarray(batch.context_lens[:bs], dtype=np.int32) + num_blocks_per_seq = cdiv(context_lens, self.block_size) + kv_indptr = np.cumsum(num_blocks_per_seq) + sum_blocks = kv_indptr[-1] + + dst = var["kv_indices"].np + offset = 0 + for i in range(bs): + bt = batch.block_tables[i] + n = len(bt) + dst[offset : offset + n] = bt + offset += n + sum_blocks_before_converted = offset + + var["kv_indptr"].np[0] = 0 + var["kv_indptr"].np[1 : bs + 1] = kv_indptr + + attn_metadata.kv_indptr = var["kv_indptr"].copy_to_gpu(bs + 1) + attn_metadata.kv_indices = var["kv_indices"].copy_to_gpu( + sum_blocks_before_converted + ) + attn_metadata.kv_last_page_lens = var["kv_last_page_lens"].gpu[:bs] + + if self.block_ratio > 1: + kv_indices_convert_triton( + var["kv_indices"].gpu[:sum_blocks_before_converted], + var["kv_indices_converted"].gpu[:sum_blocks], + var["kv_indptr"].gpu[: bs + 1], + self.block_ratio, + self.block_size, + ) + attn_metadata.kv_indices = var["kv_indices_converted"].gpu[:sum_blocks] + + # Prepare block_tables for unified_attention (fp8 prefill) + if attn_metadata.block_tables is None: + self.prepare_block_tables(batch) + attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs) + if self.block_ratio > 1: + block_table_convert_triton( + var["block_tables"].gpu[:bs], + var["block_tables_converted"].gpu[:bs], + var["context_lens"].gpu[:bs], + self.block_ratio, + ) + attn_metadata.block_tables = var["block_tables_converted"].gpu[:bs] + if self.is_sparse and attn_metadata.max_seqlen_k > self.index_topk: if attn_metadata.block_tables is None: self.prepare_block_tables(batch) diff --git a/atom/model_ops/fused_moe_triton.py b/atom/model_ops/fused_moe_triton.py index 01d83e6e3..6324dae29 100644 --- a/atom/model_ops/fused_moe_triton.py +++ b/atom/model_ops/fused_moe_triton.py @@ -34,6 +34,11 @@ from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs from triton_kernels.routing import routing from triton_kernels.matmul_ogs import PrecisionConfig + from triton_kernels.matmul_ogs import update_opt_flags_constraints + + if get_gfx() == "gfx950": + # MI355X has 160KB LDS; default CDNA4 block_m=256 exceeds it. + update_opt_flags_constraints({"block_m": 128}) except (AttributeError, ImportError) as e: logger.error( "Failed to import Triton kernels. Please make sure your triton " @@ -53,9 +58,9 @@ def _swizzle_mxfp4(quant_tensor, scale): scale_layout_opts: dict[str, Any] = {} value_layout = StridedLayout if get_gfx() == "gfx950": - from triton_kernels.tensor_details.layout import GFX950MXScaleLayout + from triton_kernels.tensor_details.layout import CDNA4MXScaleLayout - scale_layout = GFX950MXScaleLayout + scale_layout = CDNA4MXScaleLayout else: scale_layout = StridedLayout diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 1bd3538a3..83b492a99 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -700,6 +700,10 @@ def weight_loader( elif self.quant_type == QuantType.per_Tensor: shard_offset = loaded_shard_id shard_size = 1 + else: + # per_Token and per_1x32: scale dim 0 matches output_size + shard_offset = sum(self.output_sizes[:loaded_shard_id]) + shard_size = self.output_sizes[loaded_shard_id] else: shard_offset = sum(self.output_sizes[:loaded_shard_id]) shard_size = self.output_sizes[loaded_shard_id] diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index d9dc1e34c..ccceb3e72 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -48,6 +48,206 @@ from torch import nn from transformers import PretrainedConfig +import logging + +_moe_logger = logging.getLogger(__name__) + + +def _has_ck_moe_sorting() -> bool: + """Check if CK MOE sorting kernel is available.""" + try: + import importlib + + return importlib.util.find_spec("aiter.jit.module_moe_sorting") is not None + except Exception: + return False + + +def _per_token_group_quant_fp8(x, group_size, fp8_dtype): + """Quantize input tensor to FP8 with per-token-group scaling. + + Args: + x: Input tensor of shape (M, K) in bf16/fp16. + group_size: Number of elements per quantization group. + fp8_dtype: Target FP8 dtype (e.g. torch.float8_e4m3fnuz). + + Returns: + x_fp8: Quantized tensor of shape (M, K). + scale: Dequantization scale of shape (M, K // group_size). + """ + M, K = x.shape + assert K % group_size == 0 + num_groups = K // group_size + x_float = x.float() + x_grouped = x_float.view(M, num_groups, group_size) + max_abs = x_grouped.abs().amax(dim=-1) # (M, num_groups) + fp8_max = torch.finfo(fp8_dtype).max + scale = (max_abs / fp8_max).clamp(min=1e-12) + x_scaled = x_grouped / scale.unsqueeze(-1) + x_fp8 = x_scaled.clamp(-fp8_max, fp8_max).to(fp8_dtype) + x_fp8 = x_fp8.view(M, K) + return x_fp8, scale + + +def _triton_fp8_moe( + x, + w13, + w2, + topk_weights, + topk_ids, + w13_scale, + w2_scale, + top_k, + block_quant, + quant_type, +): + """Execute FP8 MOE using AITER Triton kernels (no CK dependency). + + Two-stage pipeline: + Stage 1 (GEMM1+SiLU): x @ w13^T with SiLU gating + Stage 2 (GEMM2): intermediate @ w2^T with routing weight accumulation + + For GEMM2, we reshape the intermediate so each (token, expert_k) pair is + treated as an independent token with top_k=1, allowing correct A indexing. + """ + import triton.language as tl + from aiter.ops.triton.moe.moe_align_block_size import moe_align_block_size_triton + from aiter.ops.triton.moe.moe_op_silu_fused import fused_moe_silu + from aiter.ops.triton.moe.moe_op import fused_moe as triton_fused_moe + from aiter.ops.triton.utils.moe_config_utils import get_optimal_moe_config + + M, hidden_dim = x.shape + E = w13.shape[0] + inter_dim_2 = w13.shape[1] # 2 * inter_dim + inter_dim = inter_dim_2 // 2 + # When fused shared experts are enabled, topk_ids has M*(top_k+1) elements + actual_top_k = topk_ids.numel() // M + + if block_quant: + if quant_type == QuantType.per_1x128: + block_shape = [128, 128] + elif quant_type == QuantType.per_1x32: + block_shape = [1, 32] + else: + block_shape = None + else: + block_shape = None + + config = get_optimal_moe_config(dtype=x.dtype, use_fp8_w8a8=True, M=M) + block_size_m = config["BLOCK_SIZE_M"] + compute_type = tl.bfloat16 if x.dtype == torch.bfloat16 else tl.float16 + + # --- Stage 1: Sorting --- + max_num_tokens_padded = topk_ids.numel() + E * (block_size_m - 1) + sorted_token_ids = torch.empty( + max_num_tokens_padded, dtype=torch.int32, device=x.device + ) + sorted_token_ids.fill_(topk_ids.numel()) + max_num_m_blocks = (max_num_tokens_padded + block_size_m - 1) // block_size_m + expert_ids = torch.empty(max_num_m_blocks, dtype=torch.int32, device=x.device) + num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device=x.device) + + moe_align_block_size_triton( + topk_ids, E, block_size_m, sorted_token_ids, expert_ids, num_tokens_post_pad + ) + + # --- Stage 2: GEMM1 with SiLU (x @ w13^T) --- + if block_quant and block_shape is not None: + block_k = block_shape[1] + a_fp8, a_scale = _per_token_group_quant_fp8(x, block_k, w13.dtype) + else: + a_fp8 = x + a_scale = None + + intermediate = torch.zeros( + M * actual_top_k, inter_dim, dtype=x.dtype, device=x.device + ) + + fused_moe_silu( + A=a_fp8, + B=w13, + C=intermediate, + A_scale=a_scale, + B_scale=w13_scale, + B_zp=None, + topk_weights=topk_weights, + topk_ids=topk_ids, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_pad, + mul_routed_weight=False, + top_k=actual_top_k, + compute_type=compute_type, + use_fp8_w8a8=True, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=block_shape, + config=config, + ) + + # --- Stage 3: GEMM2 (intermediate @ w2^T) --- + # Reshape for GEMM2: treat each (token, expert_k) as independent token + # with top_k=1 so the kernel indexes A correctly (A // top_k = A // 1 = A) + gemm2_topk_ids = topk_ids.reshape(M * actual_top_k, 1) + gemm2_topk_weights = topk_weights.reshape(M * actual_top_k, 1) + + # Re-sort for GEMM2 with the reshaped topk_ids + gemm2_max_padded = gemm2_topk_ids.numel() + E * (block_size_m - 1) + gemm2_sorted_ids = torch.empty(gemm2_max_padded, dtype=torch.int32, device=x.device) + gemm2_sorted_ids.fill_(gemm2_topk_ids.numel()) + gemm2_max_blocks = (gemm2_max_padded + block_size_m - 1) // block_size_m + gemm2_expert_ids = torch.empty(gemm2_max_blocks, dtype=torch.int32, device=x.device) + gemm2_num_pad = torch.empty(1, dtype=torch.int32, device=x.device) + + moe_align_block_size_triton( + gemm2_topk_ids, + E, + block_size_m, + gemm2_sorted_ids, + gemm2_expert_ids, + gemm2_num_pad, + ) + + # Quantize intermediate for FP8 GEMM2 + if block_quant and block_shape is not None: + block_k2 = block_shape[1] + inter_fp8, inter_scale = _per_token_group_quant_fp8( + intermediate, block_k2, w2.dtype + ) + else: + inter_fp8 = intermediate + inter_scale = None + + output = torch.zeros( + M * actual_top_k, 1, hidden_dim, dtype=x.dtype, device=x.device + ) + + triton_fused_moe( + A=inter_fp8, + B=w2, + C=output, + A_scale=inter_scale, + B_scale=w2_scale, + B_zp=None, + topk_weights=gemm2_topk_weights, + topk_ids=gemm2_topk_ids, + sorted_token_ids=gemm2_sorted_ids, + expert_ids=gemm2_expert_ids, + num_tokens_post_padded=gemm2_num_pad, + mul_routed_weight=True, + top_k=1, + compute_type=compute_type, + use_fp8_w8a8=True, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=block_shape, + config=config, + ) + + # Reduce: sum across top_k experts per token + result = output.squeeze(1).view(M, actual_top_k, hidden_dim).sum(dim=1) + return result + class FusedMoeWeightScaleSupported(Enum): """Supported quantization strategies for MoE weight scales.""" @@ -634,11 +834,16 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): self.quant_type == QuantType.per_1x128 or self.quant_type == QuantType.per_1x32 ) - self.use_triton = get_gfx().startswith("gfx94") - if self.use_triton: - from atom.model_ops.utils import has_triton_kernels - - assert has_triton_kernels(), "triton_kernels is not installed" + from atom.model_ops.utils import has_triton_kernels + + self.use_triton = get_gfx() in ("gfx942", "gfx950") and has_triton_kernels() + if not self.use_triton and not _has_ck_moe_sorting(): + if has_triton_kernels(): + self.use_triton = True + _moe_logger.info( + "CK MOE sorting not available, " + "using Triton MOE kernels for MXFP4" + ) def create_weights( self, @@ -980,6 +1185,14 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): self.block_n = 1 self.block_k = 32 + # Detect CK MOE availability; fall back to Triton MOE if unavailable + self.use_triton_moe = not _has_ck_moe_sorting() + if self.use_triton_moe: + _moe_logger.info( + "CK MOE sorting not available, using Triton MOE kernels " + "for CompressedTensors FP8" + ) + def create_weights( self, layer: torch.nn.Module, @@ -1220,16 +1433,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) # Shuffle weights for asm moe (moved from inference to load time for better performance) - if w13.dtype in [ - torch.int8, - torch.uint8, - torch.float8_e4m3fnuz, - torch.float8_e4m3fn, - ]: - from aiter.ops.shuffle import shuffle_weight + # Skip shuffle when using Triton path (Triton expects standard row-major) + if not self.use_triton_moe: + if w13.dtype in [ + torch.int8, + torch.uint8, + torch.float8_e4m3fnuz, + torch.float8_e4m3fn, + ]: + from aiter.ops.shuffle import shuffle_weight - w13.data = shuffle_weight(w13.data) - w2.data = shuffle_weight(w2.data) + w13.data = shuffle_weight(w13.data) + w2.data = shuffle_weight(w2.data) # Call parent class for any additional processing super().process_weights_after_loading(layer) @@ -1298,6 +1513,21 @@ def apply( a1_scale = getattr(layer, "w13_input_scale", None) a2_scale = getattr(layer, "w2_input_scale", None) + # Triton MOE fallback when CK is not available + if self.use_triton_moe: + return _triton_fp8_moe( + x=x, + w13=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + top_k=top_k, + block_quant=self.block_quant, + quant_type=self.quant_type, + ) + # Use modular kernel if available (for EP/DP setups) # Otherwise fall back to direct kernel call if self.fused_experts is not None: @@ -1362,6 +1592,12 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): self.need_normalize_e4m3fn_to_e4m3fnuz = ( self.quant_dtype == torch.float8_e4m3fnuz ) + # Detect CK MOE availability; fall back to Triton MOE if unavailable + self.use_triton_moe = not _has_ck_moe_sorting() + if self.use_triton_moe: + _moe_logger.info( + "CK MOE sorting not available, using Triton MOE kernels for FP8" + ) def create_weights( self, @@ -1525,7 +1761,8 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: layer.w2_weight = nn.Parameter(w2_weight, requires_grad=False) layer.w2_weight_scale = nn.Parameter(w2_weight_scale, requires_grad=False) - shuffle_weights(layer.w13_weight, layer.w2_weight) + if not self.use_triton_moe: + shuffle_weights(layer.w13_weight, layer.w2_weight) return else: @@ -1597,7 +1834,8 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: ) start += shard_size - shuffle_weights(layer.w13_weight, layer.w2_weight) + if not self.use_triton_moe: + shuffle_weights(layer.w13_weight, layer.w2_weight) layer.w13_weight_scale = torch.nn.Parameter( max_w13_scales, requires_grad=False @@ -1647,6 +1885,20 @@ def apply( num_fused_shared_experts=layer.num_fused_shared_experts, routed_scaling_factor=layer.routed_scaling_factor, ) + # Triton MOE fallback when CK is not available + if self.use_triton_moe: + return _triton_fp8_moe( + x=x, + w13=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + top_k=top_k, + block_quant=self.block_quant, + quant_type=self.quant_type, + ) # per_Tensor not support num_local_tokens so not use mori if self.quant_type == QuantType.per_Tensor or self.fused_experts is None: return torch.ops.aiter.rocm_aiter_fused_moe( diff --git a/docker/Dockerfile b/docker/Dockerfile index 85c99daac..7d6b9d837 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -14,6 +14,8 @@ ARG AITER_COMMIT="HEAD" ARG MORI_COMMIT="b0dce4beebeb1f26c784eee17d5fd9785ee9447f" ARG PREBUILD_KERNELS=1 ARG MAX_JOBS +# Set ENABLE_CK=0 to skip CK/ASM modules for a fast Triton-only AITER build +ARG ENABLE_CK=1 RUN pip install --upgrade pip RUN pip install lm-eval[api] @@ -63,6 +65,10 @@ RUN git clone --depth=1 --branch release/internal/3.5.x https://github.com/ROCm/ MAX_JOBS=64 pip --retries=10 --default-timeout=60 install . RUN pip show triton || true +# Install triton_kernels (required for MXFP4 MoE on gfx94x) +RUN pip install --no-deps -e /triton-test/python/triton_kernels/ +RUN pip show triton-kernels || true + # Install Aiter RUN mkdir -p /app RUN pip uninstall -y aiter || true @@ -70,8 +76,11 @@ RUN git clone $AITER_REPO /app/aiter-test && \ cd /app/aiter-test && \ pip install -r requirements.txt && \ git checkout $AITER_COMMIT && \ - git submodule sync && git submodule update --init --recursive && \ - MAX_JOBS=$MAX_JOBS PREBUILD_KERNELS=$PREBUILD_KERNELS GPU_ARCHS=$GPU_ARCH_LIST python3 setup.py develop + if [ "$ENABLE_CK" != "0" ]; then \ + git submodule sync && git submodule update --init --recursive; \ + fi && \ + ENABLE_CK=$ENABLE_CK MAX_JOBS=$MAX_JOBS PREBUILD_KERNELS=$PREBUILD_KERNELS \ + GPU_ARCHS=$GPU_ARCH_LIST python3 setup.py develop RUN pip show amd-aiter || true diff --git a/docker/Dockerfile.clean b/docker/Dockerfile.clean new file mode 100644 index 000000000..4a2d4dbc6 --- /dev/null +++ b/docker/Dockerfile.clean @@ -0,0 +1,69 @@ +# Dockerfile.clean — Wheel-only ATOM/AITER build (zero source compilation) +# +# Base: rocm/dev-ubuntu-24.04:7.2-complete (Python 3.12, full ROCm runtime) +# All packages installed from pre-built wheels — no git clones, no compiles. +# +# Option A — from pre-built wheels directory: +# cd /home/pensun/ATOM +# DOCKER_BUILDKIT=1 docker build \ +# --build-context wheels=/home/pensun/dist \ +# -f docker/Dockerfile.clean -t atom:clean . +# +# Option B — multi-stage from Dockerfile.wheels builder image: +# docker build -f docker/Dockerfile.wheels -t atom:wheels . +# DOCKER_BUILDKIT=1 docker build \ +# --build-context wheels=docker-image://atom:wheels \ +# -f docker/Dockerfile.clean -t atom:clean . +# +# Run: +# docker run --rm --device=/dev/kfd --device=/dev/dri \ +# -v /data2/models:/models atom:clean bash + +ARG BASE_IMAGE="rocm/dev-ubuntu-24.04:7.2-complete" +FROM ${BASE_IMAGE} + +ENV DEBIAN_FRONTEND=noninteractive + +# ── 1. System packages (minimal — no build tools needed) ───────────── +RUN apt-get update && apt-get install -y --no-install-recommends \ + git python3-pip python3-dev \ + ibverbs-utils libpci-dev locales \ + openmpi-bin libopenmpi-dev libdw1 \ + && rm -rf /var/lib/apt/lists/* + +RUN pip3 install --break-system-packages --ignore-installed pip setuptools wheel + +# ── 2. Install all pre-built wheels ──────────────────────────────────── +# Uses bind-mount to avoid a 60+ GB COPY layer from the wheels image. +# Works with both Option A (flat directory) and Option B (docker-image://). +RUN --mount=type=bind,from=wheels,source=/,target=/mnt/wheels \ + mkdir -p /tmp/wheels \ + && find /mnt/wheels -name '*.whl' -exec cp {} /tmp/wheels/ \; \ + && ls -lhS /tmp/wheels/*.whl \ + && pip3 install --break-system-packages --no-deps \ + /tmp/wheels/torch-*.whl \ + /tmp/wheels/torchvision-*.whl \ + /tmp/wheels/torchaudio-*.whl \ + /tmp/wheels/triton-*.whl \ + /tmp/wheels/triton_kernels-*.whl \ + && pip3 install --break-system-packages \ + filelock typing-extensions sympy networkx jinja2 fsspec numpy pillow \ + && pip3 install --break-system-packages \ + /tmp/wheels/mori-*.whl \ + /tmp/wheels/flydsl-*.whl \ + && pip3 install --break-system-packages \ + /tmp/wheels/amd_aiter-*.whl \ + && rm -rf /tmp/wheels \ + && python3 -c "import torch; print(f'PyTorch {torch.__version__}, ROCm: {torch.version.hip}')" \ + && python3 -c "import triton; print(f'Triton {triton.__version__}')" \ + && python3 -c "import aiter; print('AITER OK')" \ + && python3 -c "import flydsl; print('FlyDSL OK')" \ + && pip3 show mori && echo "MORI wheel installed OK" + +# ── 3. ATOM (from build context — pure Python, instant install) ────── +COPY . /app/ATOM +RUN cd /app/ATOM && pip3 install --break-system-packages -e . \ + && python3 -c "import atom; print('ATOM OK')" + +WORKDIR /app/ATOM +CMD ["/bin/bash"] diff --git a/docker/Dockerfile.wheels b/docker/Dockerfile.wheels new file mode 100644 index 000000000..e9da1a648 --- /dev/null +++ b/docker/Dockerfile.wheels @@ -0,0 +1,159 @@ +# Dockerfile.wheels — Build/download all wheels for ATOM clean install +# +# Produces /wheels/ containing: +# torch, torchvision, torchaudio (pulled from PyTorch nightly) +# triton 3.5.x (built from ROCm/triton source) +# triton_kernels (built from ROCm/triton source) +# flydsl (built from FlyDSL source + embedded MLIR runtime) +# mori (built from MORI source) +# amd_aiter (built with ENABLE_CK=0 + pre-compiled Triton kernels) +# +# Usage (standalone — extract wheels to host): +# docker build -f docker/Dockerfile.wheels -t atom:wheels . +# docker run --rm atom:wheels tar cf - /wheels | tar xf - -C /home/pensun/dist --strip-components=1 +# +# Usage (multi-stage — pipe directly into Dockerfile.clean): +# DOCKER_BUILDKIT=1 docker build \ +# --build-context wheels=docker-image://atom:wheels \ +# -f docker/Dockerfile.clean -t atom:clean . + +ARG BASE_IMAGE="rocm/dev-ubuntu-24.04:7.2-complete" +FROM ${BASE_IMAGE} + +ARG GPU_ARCH="gfx942;gfx950" +ARG AITER_REPO="https://github.com/sunway513/aiter.git" +ARG AITER_BRANCH="feat/prebuild-triton" +ARG FLYDSL_REPO="https://github.com/ROCm/FlyDSL.git" +ARG FLYDSL_BRANCH="main" +ARG LLVM_COMMIT="04f968b02917" +ARG MORI_REPO="https://github.com/ROCm/mori.git" +ARG MORI_COMMIT="b0dce4beebeb1f26c784eee17d5fd9785ee9447f" +ARG MAX_JOBS="" +ARG PREBUILD_TRITON=1 + +ENV GPU_ARCH_LIST=${GPU_ARCH} +ENV PYTORCH_ROCM_ARCH=${GPU_ARCH} +ENV DEBIAN_FRONTEND=noninteractive + +# ── 1. System packages + build tools ──────────────────────────────── +RUN apt-get update && apt-get install -y --no-install-recommends \ + git cmake ninja-build \ + python3-pip python3-dev python3-venv \ + ibverbs-utils libpci-dev locales \ + && rm -rf /var/lib/apt/lists/* + +RUN pip3 install --break-system-packages --ignore-installed \ + pip setuptools wheel build + +RUN mkdir -p /wheels + +# ── 2. Pull PyTorch ROCm 7.2 nightly wheels ───────────────────────── +RUN pip3 download --no-deps --dest /wheels \ + torch torchvision torchaudio \ + --index-url https://download.pytorch.org/whl/nightly/rocm7.2 + +# ── 3. Build Triton 3.5.x from ROCm fork ──────────────────────────── +RUN git clone --depth=1 --branch release/internal/3.5.x \ + https://github.com/ROCm/triton.git /build/triton + +RUN cd /build/triton \ + && pip3 install --break-system-packages -r python/requirements.txt \ + && pip3 install --break-system-packages filecheck \ + && MAX_JOBS=${MAX_JOBS:-64} pip3 wheel \ + --no-build-isolation --no-deps -w /wheels . \ + && ls -lh /wheels/triton-*.whl + +# Build triton_kernels wheel +RUN cd /build/triton/python/triton_kernels \ + && pip3 wheel --no-deps -w /wheels . \ + && ls -lh /wheels/triton_kernels-*.whl + +# ── 4. Build LLVM/MLIR for FlyDSL ─────────────────────────────────── +# Blobless clone (~6 min vs ~30 min full clone). LLVM_COMMIT rarely +# changes, so this layer stays cached across most rebuilds. +RUN pip3 install --break-system-packages nanobind numpy pybind11 + +RUN git clone --filter=blob:none --no-checkout \ + https://github.com/ROCm/llvm-project.git /build/llvm-project \ + && cd /build/llvm-project \ + && git fetch origin amd-staging \ + && git checkout ${LLVM_COMMIT} + +RUN mkdir -p /build/llvm-project/buildmlir \ + && cd /build/llvm-project/buildmlir \ + && NANOBIND_DIR=$(python3 -c "import nanobind; import os; print(os.path.dirname(nanobind.__file__) + '/cmake')") \ + && cmake -G Ninja \ + -S /build/llvm-project/llvm \ + -DLLVM_ENABLE_PROJECTS="mlir;clang" \ + -DLLVM_TARGETS_TO_BUILD="X86;NVPTX;AMDGPU" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CXX_STANDARD=17 \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_INSTALL_UTILS=ON \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE=$(which python3) \ + -Dnanobind_DIR="$NANOBIND_DIR" \ + -DBUILD_SHARED_LIBS=OFF \ + && cmake --build . -j$(nproc) \ + && cmake --install . --prefix /build/llvm-project/mlir_install + +# ── 5. Install torch + triton (needed for AITER/MORI builds) ──────── +RUN pip3 install --break-system-packages --no-deps \ + /wheels/torch-*.whl /wheels/triton-3.5*.whl \ + && pip3 install --break-system-packages \ + filelock typing-extensions sympy networkx jinja2 fsspec numpy + +# ── 6. Build FlyDSL wheel ─────────────────────────────────────────── +RUN git clone --depth=1 --branch ${FLYDSL_BRANCH} ${FLYDSL_REPO} /build/FlyDSL + +RUN cd /build/FlyDSL \ + && export MLIR_PATH=/build/llvm-project/mlir_install \ + && bash flir/build.sh \ + && export FLIR_IN_BUILD_SH=1 \ + && pip3 install --break-system-packages auditwheel patchelf \ + && python3 setup.py bdist_wheel \ + && cp dist/flydsl-*.whl /wheels/ \ + && ls -lh /wheels/flydsl-*.whl + +# ── 7. Build MORI wheel ───────────────────────────────────────────── +RUN apt-get update && apt-get install -y --no-install-recommends \ + openmpi-bin libopenmpi-dev cython3 libdw1 \ + && rm -rf /var/lib/apt/lists/* + +# Patch PyTorch's Caffe2Config.cmake: the ROCm nightly wheel's config +# hard-errors when CUDA toolkit is not found, even though we only need ROCm. +# Convert the fatal error to a warning so MORI (and other torch-cmake users) +# can build against the ROCm PyTorch wheel without CUDA installed. +RUN CAFFE2_CFG=$(python3 -c "import torch, pathlib; print(pathlib.Path(torch.__file__).parent / 'share/cmake/Caffe2/Caffe2Config.cmake')") \ + && sed -i 's/message(FATAL_ERROR "Your installed Caffe2 version uses CUDA/message(WARNING "Skipped: Your installed Caffe2 version uses CUDA/' "$CAFFE2_CFG" + +RUN git clone ${MORI_REPO} /build/mori \ + && cd /build/mori \ + && git checkout ${MORI_COMMIT} \ + && grep -iv '^torch\|^triton' requirements-build.txt \ + | pip3 install --break-system-packages -r /dev/stdin \ + && git submodule update --init --recursive \ + && pip3 wheel --no-build-isolation --no-deps -w /wheels . \ + && ls -lh /wheels/mori-*.whl + +# ── 8. Build AITER wheel (ENABLE_CK=0, pre-compiled Triton kernels) ── +RUN git clone --depth=1 --branch ${AITER_BRANCH} ${AITER_REPO} /build/aiter + +# Set AITER build env for all subsequent commands in this layer +RUN cd /build/aiter \ + && pip3 install --break-system-packages -r requirements.txt \ + && export ENABLE_CK=0 PREBUILD_TRITON=${PREBUILD_TRITON} \ + PREBUILD_TRITON_ARCHS="gfx942,gfx950" \ + MAX_JOBS=${MAX_JOBS} GPU_ARCHS=${GPU_ARCH_LIST} \ + && pip3 install --break-system-packages --no-build-isolation -e . \ + && python3 -c "import aiter; print('editable install OK')" \ + && echo "install" > aiter/install_mode \ + && python3 setup.py bdist_wheel \ + && cp dist/amd_aiter-*.whl /wheels/ \ + && ls -lh /wheels/amd_aiter-*.whl + +# ── 9. Summary ────────────────────────────────────────────────────── +RUN echo "=== Wheel inventory ===" && ls -lhS /wheels/*.whl && echo "=== Done ===" + +WORKDIR /wheels +CMD ["ls", "-lhS", "/wheels/"] diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 324c10a9c..48155c046 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: MIT # Tests for atom/model_engine/scheduler.py — public API only +import numpy as np + from atom.model_engine.scheduler import Scheduler, ScheduledBatchOutput from atom.model_engine.sequence import SequenceStatus, SequenceType from atom.sampling_params import SamplingParams @@ -121,7 +123,9 @@ def _prefill(self, scheduler, seq): def _output(self, seq_id, tokens): return ScheduledBatchOutput( - token_ids={seq_id: tuple(tokens)}, draft_token_ids=None + token_ids={seq_id: tuple(tokens)}, + num_rejected=np.zeros(0, dtype=np.int32), + draft_token_ids=None, ) def test_appends_token(self, scheduler, seq_factory): @@ -166,7 +170,11 @@ def test_stop_token_ids(self, seq_factory): sched.schedule() finished = sched.postprocess( list(sched.running), - ScheduledBatchOutput(token_ids={seq.id: (99,)}, draft_token_ids=None), + ScheduledBatchOutput( + token_ids={seq.id: (99,)}, + num_rejected=np.zeros(0, dtype=np.int32), + draft_token_ids=None, + ), ) assert len(finished) == 1 assert "stop_99" in finished[0].leave_reason