YiRage (Yield Revolutionary AGile Engine) extends Mirage with comprehensive multi-backend support, enabling LLM inference optimization across diverse hardware platforms.
- Original Mirage (CMU): Superoptimizer framework for tensor programs
- YiRage Extensions (Chen Xingqiang, 2025): Multi-backend support with hardware-aware optimizations
```text
                       YiRage Complete Backend Architecture

  Layer 1: Python API
    yirage.new_kernel_graph() · UnifiedCompiler · CoreBridge · superoptimize()
                                        │
                                        ▼
  Layer 2: Backend Manager (C++)
    BackendRegistry (thread-safe) · BackendFactory · StrategyFactory
                                        │
                                        ▼
  Layer 3: Search & Strategy
    Hardware-aware Search · Fingerprint Verification · Performance Profiling
                                        │
                                        ▼
  Layer 4: Threadblock Operations
    MatMul · Attention · RMSNorm · SwiGLU · Softmax · Reduce · Elementwise
                                        │
                                        ▼
  Layer 5: Persistent Kernel Runtime
    Memory Management · Kernel Launch · Synchronization · JIT Compilation
                                        │
                                        ▼
  Hardware Layer
    CUDA (NVIDIA) · ROCm (AMD) · MPS (Apple) · Ascend (Huawei) · MACA (MetaX)
    TPU (Google) · XPU (Intel) · FPGA (Xilinx) · CPU (x86/ARM)
    Compiler backends: Triton (OpenAI) · NKI (AWS) · MLIR (LLVM)
```
| Backend | Hardware | Backend API | Strategy | Kernel | Threadblock | PK Runtime |
|---|---|---|---|---|---|---|
| CUDA | NVIDIA GPU | ✅ | ✅ | ✅ | ✅ | ✅ |
| ROCm | AMD GPU | ✅ | ✅ | ✅ | ✅ | ✅ |
| CPU | x86/ARM | ✅ | ✅ | ✅ | ✅ | ✅ |
| MPS | Apple Silicon | ✅ | ✅ | ✅ | ✅ | ✅ |
| Ascend | Huawei NPU | ✅ | ✅ | ✅ | ✅ | ✅ |
| MACA | MetaX GPU | ✅ | ✅ | ✅ | ✅ | ✅ |
| TPU | Google Cloud | ✅ | ✅ | ✅ | ✅ | ✅ |
| XPU | Intel GPU | ✅ | ✅ | ✅ | ✅ | ✅ |
| FPGA | Intel/Xilinx | ✅ | ✅ | ✅ | ✅ | ✅ |
| Triton | Compiler | ✅ | ✅ | ✅ | ✅ | ✅ |
| NKI | AWS Neuron | ✅ | ✅ | ✅ | ✅ | ✅ |
| MLIR | Multi-target | ✅ | ✅ | ✅ | ✅ | ✅ |
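The support matrix can be checked at runtime through the Layer 1 API; a minimal sketch (the same query calls used in the Quick Start below):

```python
import yirage as yr

# Backends are registered automatically when yirage is imported (Layer 2)
print(yr.get_available_backends())   # e.g. ['cuda', 'cpu'], depending on hardware

# Guard optional backends before relying on them
if yr.is_backend_available('mps'):
    print("MPS backend registered")
```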
Layer 1: Python API
- Backend query and selection (`get_available_backends()`)
- Unified compiler interface (`UnifiedCompiler`)
- Core bridge to C++ (`CoreBridge`)
- Hardware-specific optimizers
Layer 2: Backend Manager (C++)
- BackendRegistry (singleton, thread-safe)
- Factory patterns for backends and strategies
- Automatic initialization on import
Layer 3: Search & Strategy
- Hardware-aware kernel search
- Fingerprint-based verification
- Performance profiling and modeling
Layer 4: Threadblock Operations
- Optimized LLM operators (MatMul, Attention, RMSNorm, SwiGLU)
- Hardware-specific implementations
- Code generation for Triton/NKI/MLIR
Layer 5: Persistent Kernel Runtime
- Device memory management
- Kernel launch and synchronization
- JIT compilation support
| Backend | Hardware | Key Features | Architecture |
|---|---|---|---|
| CUDA | NVIDIA GPU | Tensor Core, 32-thread Warp, cuBLAS | SM, Shared Memory |
| ROCm | AMD GPU | Matrix Core, 64-thread Wavefront, rocBLAS | GCN/CDNA, LDS |
| CPU | x86/ARM | AVX512/NEON SIMD, Cache Blocking, OpenMP | Multi-core, L1/L2/L3 |
| MPS | Apple Silicon | Metal, Threadgroup, Unified Memory | M1/M2/M3/M4 |
| Ascend | Huawei NPU | Cube Unit 16×16, AI Core, L1 Buffer | Ascend 910/310 |
| MACA | MetaX GPU | 64-thread Warp, CUDA-compat, Tensor Core | C500 Series |
| TPU | Google Cloud | MXU 128×128, BF16 Native, PJRT | TPU v2/v3/v4/v5 |
| XPU | Intel GPU | XMX 8×8, SYCL/oneAPI, SLM | Arc/Max/Gaudi |
| FPGA | Intel/Xilinx | DSP Blocks, Pipeline, BRAM/HBM | OpenCL Kernel |
| Triton | Compiler | Auto-tuning, Tile Fusion, MMA | PTX/HSACO |
| NKI | AWS Neuron | Tensor Engine 128×128, SBUF 24MB | Trainium/Inferentia |
| MLIR | Multi-target | JIT, Linalg, Pass Pipeline | LLVM/NVVM/SPIRV |
Hardware Architecture Comparison:

| Backend | Thread Model | Matrix Unit | Memory Hierarchy |
|---|---|---|---|
| CUDA | 32-thread Warp | Tensor Core | Registers → Shared → L2 → HBM |
| ROCm | 64-thread Wavefront | Matrix Core | VGPR → LDS → L2 → HBM |
| MPS | SIMD Group | Apple GPU | Threadgroup → Device → Unified |
| Ascend | AI Core | Cube 16×16 | L0 → L1 → L2 → HBM |
| MACA | 64-thread Warp | Tensor Core | Shared → L2 → HBM |
| TPU | MXU Systolic | MXU 128×128 | VMEM → HBM |
| XPU | Xe Subgroup | XMX 8×8 | SLM → L3 → HBM |
| FPGA | Pipeline | DSP Block | BRAM/URAM → DDR/HBM |
- 60+ Optimization Methods across all 12 backends
- Automatic Configuration based on hardware capabilities
- Performance Modeling for each backend
- Code Generation for Triton/NKI/MLIR
```python
# CUDA: auto-configure grid/block dimensions for Tensor Cores
from yirage.kernel.cuda import CUDAOptimizer, CUDAKernelConfig

config = CUDAKernelConfig()
CUDAOptimizer.optimize_grid_block_dims(1024, 1024, 1024,
                                       compute_capability=80,
                                       config=config)
# Auto-configured: Tensor Core, Warps, Shared Memory, Occupancy
```

```python
# MPS: Apple Silicon optimization
from yirage.kernel.mps import MPSOptimizer, MPSKernelConfig

config = MPSKernelConfig()
MPSOptimizer.optimize_for_apple_silicon(1024, 1024, 1024, config)
# Auto-detects: M1/M2/M3, GPU cores, Threadgroup size
```

```python
import yirage as yr

# Create and optimize for Ascend NPU
graph = yr.new_kernel_graph()
X = graph.new_input(dims=(8, 4096), dtype=yr.float16)
W = graph.new_input(dims=(4096, 4096), dtype=yr.float16)
O = graph.matmul(X, W)
graph.mark_output(O)

# Optimize using Ascend backend (via BiSheng + Triton)
optimized = graph.superoptimize(backend='ascend')
# Auto-configures: AI Core blocks, Cube unit tiles, L1 buffer
```

```python
import yirage as yr

# Create and optimize for MetaX MACA GPU
graph = yr.new_kernel_graph()
X = graph.new_input(dims=(8, 4096), dtype=yr.float16)
W = graph.new_input(dims=(4096, 4096), dtype=yr.float16)
O = graph.matmul(X, W)
graph.mark_output(O)

# Optimize using MACA backend (64-thread warps!)
optimized = graph.superoptimize(backend='maca')
# Auto-configures: 64-thread warps, tile sizes, shared memory
# Environment: export MACA_HOME=/opt/maca
```

```python
import yirage as yr

# Create and optimize for AMD GPU
graph = yr.new_kernel_graph()
X = graph.new_input(dims=(8, 4096), dtype=yr.float16)
W = graph.new_input(dims=(4096, 4096), dtype=yr.float16)
O = graph.matmul(X, W)
graph.mark_output(O)

# Optimize using ROCm backend
optimized = graph.superoptimize(backend='rocm')
# Auto-configures: 64-thread wavefronts, LDS, Matrix Cores (MI200/MI300)
# Environment: export ROCM_PATH=/opt/rocm
```

```python
import yirage as yr

# Create and optimize for Google TPU
graph = yr.new_kernel_graph()
X = graph.new_input(dims=(8, 4096), dtype=yr.bfloat16)
W = graph.new_input(dims=(4096, 4096), dtype=yr.bfloat16)
O = graph.matmul(X, W)
graph.mark_output(O)

# Optimize using TPU backend
optimized = graph.superoptimize(backend='tpu')
# Auto-configures: 128x128 MXU, BF16 native, VMEM tiling
```

```python
import yirage as yr
from yirage.pk import MLIRPKBackend
from yirage.threadblock.mlir_ops import MLIRCodeGenerator, MLIRTileConfig

# Generate MLIR for MatMul
config = MLIRTileConfig(tile_sizes=[32, 32, 32], vectorize=True)
mlir_code = MLIRCodeGenerator.generate_matmul(1024, 1024, 1024,
                                              dtype=yr.float16, config=config)

# JIT compile and execute
backend = MLIRPKBackend(target=MLIRPKBackend.JIT_TARGET_CPU)
backend.initialize()
backend.jit_compile(mlir_code)
backend.execute("matmul", [A_ptr, B_ptr, C_ptr], 3)
```
```text
                          Search & Optimization Flow

  Kernel Graph ──▶ Search Engine
                     Candidate Gen (µGraph Space) ─▶ Fingerprint Verification ─▶ Performance Profiler
                                        │
                                        ▼
                     Backend-Specific Strategies
    CUDA:   TensorCore, 32-Warp    ROCm: MatrixCore, 64-Wave   MPS:    ThreadGrp, SIMD
    Ascend: CubeUnit, AI Core      MACA: 64-Warp, TensorCore   TPU:    MXU 128×128
    XPU:    XMX, SYCL              FPGA: Pipeline, DSP         Triton: AutoTune, TileFuse
    NKI:    TensorEng, SBUF        MLIR: LinalgOpt, JIT/AOT    CPU:    SIMD/OMP, CacheBlock
                                        │
                                        ▼
                     Optimized Kernel (Best Configuration)
```
- 12 Independent Search Strategies with hardware-specific optimization
- 20+ Candidate Generation Dimensions
- 15 Performance Evaluation Metrics
- Auto-tuning and performance modeling
- Code generation for compiler backends (Triton, NKI, MLIR)
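Tying the search layer to the per-backend examples above, a minimal sketch that runs the hardware-aware search once per registered backend (assuming each backend accepts the same graph definition):

```python
import yirage as yr

def build_graph():
    # Same matmul µGraph as in the per-backend examples above
    graph = yr.new_kernel_graph()
    X = graph.new_input(dims=(8, 4096), dtype=yr.float16)
    W = graph.new_input(dims=(4096, 4096), dtype=yr.float16)
    graph.mark_output(graph.matmul(X, W))
    return graph

# One search per available backend; each returns an optimized kernel graph
optimized = {b: build_graph().superoptimize(backend=b)
             for b in yr.get_available_backends()}
```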
```bash
git clone https://github.com/chenxingqiang/YiRage.git
cd YiRage
pip install -e .  # Auto-detects CUDA/MPS/CPU
```

```bash
# Using environment variable
YIRAGE_BACKEND=cuda pip install -e .    # NVIDIA GPU
YIRAGE_BACKEND=rocm pip install -e .    # AMD GPU
YIRAGE_BACKEND=mps pip install -e .     # Apple Silicon
YIRAGE_BACKEND=ascend pip install -e .  # Huawei NPU
YIRAGE_BACKEND=maca pip install -e .    # MetaX GPU
YIRAGE_BACKEND=cpu pip install -e .     # CPU only

# Multiple backends
YIRAGE_BACKEND=cuda,cpu pip install -e .
```

Full Ascend Guide:

```bash
# Load environment
source /usr/local/Ascend/ascend-toolkit/set_env.sh
pip install torch_npu

# Install
YIRAGE_BACKEND=ascend pip install -e .
```

Full Installation Guide - All backends and options
```python
import yirage as yr

# Query available backends
backends = yr.get_available_backends()
print(f"Available backends: {backends}")
# Output: ['cuda', 'cpu', 'mps']  # depends on your hardware

# Check specific backend
if yr.is_backend_available('mps'):
    print("Apple Silicon GPU ready!")

# Create kernel with backend selection
mpk = yr.PersistentKernel(
    mode="decode",
    backend="mps",              # Specify backend
    fallback_backends=["cpu"],  # Auto fallback
    world_size=1,
    mpi_rank=0,
    # ... other parameters
)
```

```python
# CUDA optimization
from yirage.kernel.cuda import CUDAOptimizer, CUDAKernelConfig

cuda_config = CUDAKernelConfig()
CUDAOptimizer.optimize_grid_block_dims(m=1024, n=1024, k=1024,
                                       compute_capability=80,
                                       config=cuda_config)

# CPU optimization
from yirage.kernel.cpu import CPUOptimizer, CPUKernelConfig

cpu_config = CPUKernelConfig()
CPUOptimizer.optimize_for_cpu(m=1024, n=1024, k=1024, config=cpu_config)
# Auto-detects: SIMD type, CPU cores, cache sizes

# MPS optimization (Apple Silicon)
from yirage.kernel.mps import MPSOptimizer, MPSKernelConfig

mps_config = MPSKernelConfig()
MPSOptimizer.optimize_for_apple_silicon(m=1024, n=1024, k=1024, config=mps_config)
# Auto-detects: GPU family (M1/M2/M3), cores, memory
```

| Benchmark | MPS (ms) | CPU (ms) |
|---|---|---|
| gated_mlp | 0.677 | 1.268 |
| rms_norm | 0.463 | 0.115 |
| lora | 0.637 | 0.590 |
| gqa | 0.554 | - |
| norm_transformer | 1.195 | - |
All benchmarks support CUDA, MPS, and CPU backends
YiRage now supports RL-guided kernel search using Ray/RLlib, enabling intelligent exploration of the kernel configuration space.
```text
                     RL-YiRage Hierarchical Closed Loop

  Level 1: Config Policy
    HardwareConfig (grid_dim, block_dim, forloop)
        │ constrains                               ▲
        ▼                                          │ reward
  Level 2: Graph Policy (constrained by Level 1)   │
    µGraph actions ─▶ C++ Search ─▶ GPU Verify ────┘
        ▲
        └─ µGraph features (from C++), policy update (RLlib)
```
```bash
# Run integration tests (no GPU required)
python scripts/test_rl_integration.py

# Test locally (no GPU required)
python scripts/train_rl_kernel_search.py --mode local --test-episodes 10

# Train with Ray/RLlib (requires GPU for verification)
python scripts/train_rl_kernel_search.py --mode train \
    --algorithm PPO \
    --num-workers 8 \
    --max-iterations 1000

# Search with trained policy
python scripts/train_rl_kernel_search.py --mode search \
    --checkpoint /path/to/checkpoint \
    --target-graph examples/matmul.json
```

```python
from yirage.rl import YiRageSearchEnv, EnvConfig, train_rl_search
# Create environment
env_config = EnvConfig(
target_graph_json=target_graph,
backend="cuda",
num_gpus=4,
)
# Option 1: Use as Gymnasium environment
env = YiRageSearchEnv(vars(env_config))
obs, info = env.reset()
action = env.action_space.sample()
obs, reward, done, truncated, info = env.step(action)
# Option 2: Train with RLlib
from yirage.rl import TrainingConfig
config = TrainingConfig(
algorithm="PPO",
num_workers=8,
max_iterations=500,
)
results = train_rl_search(config)
```

```python
from yirage.rl.search import (
HardwareConfig, SearchSpaceConstraints,
ConstrainedGraphActionSpace, HierarchicalSearchEnv
)
# Level 1: Configure hardware parameters
config = HardwareConfig(
grid_dim_x=4, grid_dim_y=2, grid_dim_z=1,
block_dim_x=128, block_dim_y=1, block_dim_z=1,
forloop_range=16,
shared_memory_size=49152
)
# Level 2: Get constraints for graph search
constraints = SearchSpaceConstraints(config)
print(f"Valid imaps: {len(constraints.valid_imaps)}")
print(f"Max operators: {constraints.max_operators}")
# Create constrained graph action space
graph_space = ConstrainedGraphActionSpace(constraints)
```

```python
from yirage.rl.features import MuGraphFeature, FeatureProcessor
# Features extracted from C++ layer (or simulated JSON)
features = MuGraphFeature.from_json(features_json)
print(f"Operators: {len(features.operators)}")
print(f"Graph depth: {features.graph_depth}")
# Process for neural network input
processor = FeatureProcessor()
processed = processor.process(features)
# node_features: (num_nodes, 16)
# edge_index: (2, num_edges)
# global_features: (48,)
```

- Hierarchical Search: Level 1 (config) constrains Level 2 (µGraph)
- Complete Closed Loop: RL decisions → C++ search → GPU verification → reward
- Β΅Graph Feature Extraction: Rich features from C++ layer for RL model input
- Multi-objective Reward: Balances validity, performance, efficiency, exploration
- Ray Integration: Distributed CPU workers + GPU verification
- Action Masking: Prevents invalid actions based on search state
- Model Persistence: Save/load trained policies, export to ONNX
```python
from yirage.rl.hardware import detect_hardware, get_optimal_config
from yirage.rl.training import GRPOConfig, GRPOTrainer
# Auto-detect hardware
hardware = detect_hardware()
print(f"Detected: {hardware.backend} - {hardware.device_name}")
print(f"Peak FP16: {hardware.peak_tflops_fp16} TFLOPS")
# Get optimal config for workload
config = get_optimal_config(hardware, workload)
# Train with GRPO (supports LoRA fine-tuning)
grpo_config = GRPOConfig(
group_size=8,
learning_rate=1e-4,
use_lora=True,
lora_rank=16,
)
```

```python
from yirage.rl.training import FineTuningConfig, MuGraphPolicyTrainer
# Configure fine-tuning with TRL
config = FineTuningConfig(
strategy="dpo", # sft, dpo, grpo, ppo
model_name_or_path="meta-llama/Llama-2-7b-hf",
use_lora=True,
use_4bit=True, # QLoRA
lora_r=16,
)
# Train policy model
trainer = MuGraphPolicyTrainer(config)
trainer.train(train_data)
# Generate optimal configs
configs = trainer.generate_config(target_graph, hardware)
```

Optimize any compute task on any hardware at any cluster scale with a single function call:

```python
from yirage.rl.cluster import optimize_any_task
# Optimize with one line
result = optimize_any_task(
{"type": "attention", "batch": 32, "seq_len": 2048, "num_heads": 32},
cluster_spec={"type": "multi_node", "num_nodes": 4, "gpus_per_node": 8}
)
print(f"Strategy: {result.result.parallelism_strategy}") # e.g., "tensor_parallel_8"
print(f"Latency: {result.result.estimated_latency_ms:.2f} ms")
print(f"Throughput: {result.result.estimated_throughput_tps:.1f} samples/sec")
# Get kernel configs for YiRage search
for op_id, config in result.kernel_configs.items():
print(f"{op_id}: {config}")Device Registry (25+ Pre-defined Devices):
from yirage.rl.cluster import (
ClusterTopology, DeviceRegistry, get_device_spec, register_custom_device
)
# Create heterogeneous cluster from registry
cluster = ClusterTopology.create_from_registry([
"H100_SXM:4", # 4x NVIDIA H100
"MI300X:2", # 2x AMD MI300X
"TPUv4:2", # 2x Google TPU v4
"Ascend910B:2", # 2x Huawei Ascend
])
# Register custom hardware
register_custom_device("MyAccelerator", {
"device_type": "custom",
"compute_units": 128,
"peak_tflops_fp16": 500.0,
"memory_gb": 64.0,
"memory_bandwidth_gbps": 2000.0,
})
```

Supported Device Types:
| Category | Devices |
|---|---|
| NVIDIA GPU | H100, A100, V100, RTX 4090, RTX 3090 |
| AMD GPU | MI300X, MI250X |
| Intel | Max 1550 (XPU) |
| Google | TPU v4, TPU v5e |
| Huawei | Ascend 910B, Ascend 910, Ascend 310 |
| AWS | Trainium2, Inferentia2 |
| Apple | M2 Ultra, M3 Max (MPS) |
| MetaX | C500 (MACA) |
| CPU | EPYC 9654, Xeon 8480 |
| FPGA | Alveo U280 |
| Custom | User-defined devices |
Key features:
- Any Task: MatMul, Attention, MLP, Transformer, or custom graphs
- Any Hardware: CPU, GPU, NPU, TPU, FPGA, or custom accelerators
- Any Scale: Single device to multi-node clusters
- Simulation-based: Accurate communication modeling without real cluster
- µGraph Integration: Generates search space for YiRage kernel optimization
- Device Registry: 25+ pre-defined devices with full specs
- RL Closed-Loop Design
- Hierarchical Search Design
- Feature Extraction Design
- Hardware-Aware Training Design
- Universal Optimization Design
YiRage integrates the COMET framework for modeling and optimizing compound operation dataflows with explicit collective communication, based on the research paper:
"COMET: A Framework for Modeling Compound Operation Dataflows with Explicit Collectives" (Negi et al.)
- Compound Operations: Fused execution of GEMM-Softmax, GEMM-LayerNorm, Self-Attention, Gated MLP
- Explicit Collectives: AllReduce, AllGather, ReduceScatter, Broadcast with accurate cost modeling
- Data Staging Model: Ramp-up/steady-state/ramp-down phases for memory hierarchy
- Scheduling Strategies: Sequential, Pipelined, Parallel execution modes
- Energy & Latency Estimation: Detailed breakdown for optimization decisions
```python
import yirage as yr
# Create kernel graph
graph = yr.new_kernel_graph()
# GEMM-Softmax fusion (reduces DRAM traffic by keeping intermediate on-chip)
A = graph.new_input(dims=(1024, 512), dtype=yr.float16)
B = graph.new_input(dims=(512, 1024), dtype=yr.float16)
result = graph.gemm_softmax(A, B, dim=-1)
# GEMM-LayerNorm fusion
result_ln = graph.gemm_layernorm(A, B, normalized_shape=(1024,))
# Self-Attention (FlashAttention-style fusion)
Q = graph.new_input(dims=(8, 1024, 64), dtype=yr.float16) # [H, S, D]
K = graph.new_input(dims=(8, 64, 1024), dtype=yr.float16) # [H, D, S] (transposed)
V = graph.new_input(dims=(8, 1024, 64), dtype=yr.float16) # [H, S, D]
attn_out = graph.self_attention(Q, K, V)
# Gated MLP (LLM-style with SiLU activation)
X = graph.new_input(dims=(8, 1024, 4096), dtype=yr.float16)
W_gate = graph.new_input(dims=(4096, 11008), dtype=yr.float16)
W_up = graph.new_input(dims=(4096, 11008), dtype=yr.float16)
W_down = graph.new_input(dims=(11008, 4096), dtype=yr.float16)
mlp_out = graph.gated_mlp(X, W_gate, W_up, W_down, activation="silu")
# RMSNorm + Linear (common in attention QKV projection)
norm_out = graph.rms_norm_linear(X, W_gate, normalized_shape=(4096,))
graph.mark_output(result)
optimized = graph.superoptimize(backend="cuda")
```

```python
from yirage.rl.cluster.simulator import (
COMETCostModel, COMETHardwareConfig,
SchedulingStrategy, MemoryLevel, CommunicationType
)
# Create cost model with hardware config
hw_config = COMETHardwareConfig(
dram_bandwidth_gbps=900.0, # HBM2e
global_buffer_bandwidth_gbps=3000.0, # On-chip L2
num_compute_units=108, # SMs on A100
peak_tflops_fp16=312.0,
)
cost_model = COMETCostModel(hw_config=hw_config)
# Estimate compound operation latency and energy
latency, energy = cost_model.estimate_compound_operation(
op_name="gemm_softmax",
input_shapes=[(2048, 1024), (1024, 2048)],
dtype_bytes=2, # FP16
num_devices=4,
strategy=SchedulingStrategy.PIPELINED,
)
print(f"Total latency: {latency.total_latency_ms:.3f} ms")
print(f" - Compute: {latency.compute_latency_ms:.3f} ms")
print(f" - Memory: {latency.total_memory_latency_ms:.3f} ms")
print(f" - Collective: {latency.collective_latency_ms:.3f} ms")
print(f"Total energy: {energy.total_energy_mj:.3f} mJ")
# Compare distributed variants (local vs distributed execution)
results = cost_model.compare_distributed_variants(
op_name="gemm_softmax",
input_shapes=[(4096, 2048), (2048, 4096)],
num_devices=8,
)
print(f"Speedup with distribution: {results['speedup']:.2f}x")from yirage.rl.cluster.simulator import CommunicationModel, CommunicationType
comm_model = CommunicationModel()
# Ring AllReduce latency (Eq. 3-4 from COMET paper)
latency_ms = comm_model.all_reduce_time_ms(
size_bytes=100 * 1024 * 1024, # 100 MB
num_devices=8,
bandwidth_gbps=200.0, # NVLink
latency_us=1.0,
algorithm="ring",
)
print(f"AllReduce latency: {latency_ms:.3f} ms")
# AllGather and ReduceScatter
gather_time = comm_model.all_gather_time_ms(
size_bytes=50 * 1024 * 1024,
num_devices=8,
bandwidth_gbps=200.0,
latency_us=1.0,
)
print(f"AllGather latency: {gather_time:.3f} ms")The cost model implements the COMET paper equations:
| Equation | Description | Formula |
|---|---|---|
| Eq. 1 | Memory Transaction | MemLat(T) = DV / BW |
| Eq. 2 | Data Staging | TotalMem = RampUp + Steady + RampDown |
| Eq. 3-4 | Ring Collective | CollLat = 2(n-1)/n × size / bw |
| Eq. 5-7 | Scheduling | Stall = CS + OS + CF |
Where:
- DV: Data Volume, BW: Bandwidth
- CS: Compulsory Stall (data dependency)
- OS: Optional Stall (resource blocking)
- CF: Conflict Stall (resource contention)
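As a sanity check of Eq. 3-4, the ring AllReduce latency from the CommunicationModel example above can be reproduced by hand (a back-of-the-envelope sketch; the exact per-hop latency handling inside CommunicationModel may differ):

```python
# Ring AllReduce, Eq. 3-4: CollLat = 2(n-1)/n * size / bw, plus per-hop latency
size_bytes = 100 * 1024 * 1024      # 100 MB buffer
num_devices = 8
bandwidth_gbps = 200.0              # NVLink-class link, decimal GB/s
latency_us = 1.0                    # per-hop launch latency

transfer_s = 2 * (num_devices - 1) / num_devices * size_bytes / (bandwidth_gbps * 1e9)
hops_s = 2 * (num_devices - 1) * latency_us * 1e-6

print(f"transfer: {transfer_s * 1e3:.3f} ms")             # ~0.918 ms
print(f"total:    {(transfer_s + hops_s) * 1e3:.3f} ms")  # ~0.932 ms
```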
YiRage provides a complete search strategy for COMET compound operations:
```python
from yirage.search import (
COMETSearchStrategy,
get_backend_config,
detect_compound_patterns,
optimize_compound_graph,
)
# Auto-detect compound patterns in a graph
op_types = ["matmul", "exp", "reduction", "div", "matmul"] # Self-attention
patterns = detect_compound_patterns(op_types)
print(f"Found {len(patterns)} compound patterns: {[p.op_type.name for p in patterns]}")
# Get backend-specific configuration (15 hardware profiles)
config = get_backend_config("cuda", "a100") # NVIDIA A100
# Or: get_backend_config("rocm", "mi300x") # AMD MI300X
# Or: get_backend_config("tpu", "v5e") # Google TPU v5e
# Or: get_backend_config("ascend", "910b") # Huawei Ascend
# Run COMET search to find optimal configuration
strategy = COMETSearchStrategy(config)
result = strategy.search(
op_types=op_types,
problem_dims={"M": 4096, "K": 4096, "N": 4096}
)
print(f"Best tile config: M={result.tile_config.tile_m}, N={result.tile_config.tile_n}")
print(f"Scheduling: {result.scheduling.name}")
print(f"Estimated latency: {result.latency_ns:.2f} ns")| Backend | Variant | DRAM BW (GB/s) | Peak TFLOPS | Tile Sizes |
|---|---|---|---|---|
| CUDA | H100 | 3350 | 989 | 64, 128, 256 |
| CUDA | A100 | 2039 | 312 | 64, 128, 256 |
| CUDA | V100 | 900 | 125 | 32, 64, 128 |
| ROCm | MI300X | 5300 | 1307 | 64, 128, 256 |
| ROCm | MI250X | 3200 | 383 | 64, 128, 256 |
| XPU | Ponte Vecchio | 3200 | 420 | 32, 64, 128, 256 |
| Ascend | 910B | 1600 | 320 | 64, 128, 256 |
| TPU | v5e | 1600 | 197 | 128, 256, 512 |
| TPU | v4 | 1200 | 275 | 128, 256, 512 |
| MACA | MXC500 | 2000 | 256 | 64, 128, 256 |
| MPS | M3 Max | 400 | 14.2 | 32, 64, 128 |
| MPS | M2 Ultra | 800 | 27.2 | 32, 64, 128 |
| CPU | Xeon | 200 | 4.0 | 32, 64, 128, 256 |
| CPU | EPYC | 460 | 5.0 | 32, 64, 128, 256 |
| FPGA | Alveo | 77 | 4.0 | 16, 32, 64, 128 |
YiRage provides production-grade distributed optimization with deep Ray integration:
| Feature | Description |
|---|---|
| C++ Binding | Direct search_partition() API via Cython for native performance |
| Object Store | ray.put() for efficient large graph data sharing |
| Placement Groups | GPU affinity with PACK/SPREAD strategies for NVLink |
| Fault Tolerance | Exponential backoff retry + checkpoint/restore |
| Collective Ops | Efficient all-reduce for gradient synchronization |
```python
from yirage.distributed import (
RayDeepIntegration,
DeepIntegrationConfig,
GPUPlacementConfig,
RetryConfig,
RetryStrategy,
)
# Configure distributed optimization
config = DeepIntegrationConfig(
num_workers=8,
gpu_placement=GPUPlacementConfig(
gpus_per_worker=1,
strategy="PACK", # NVLink locality
),
retry=RetryConfig(
strategy=RetryStrategy.EXPONENTIAL,
max_retries=5,
),
use_object_store=True,
)
# Create engine and optimize
engine = RayDeepIntegration(config)
result = engine.optimize(
graph={"type": "matmul", "input_shapes": [[1024, 2048], [2048, 4096]]},
search_space={"grid_dims": [(1,1,1), (2,1,1), (4,1,1)], "block_dims": [(128,1,1), (256,1,1)]},
)
print(f"Best latency: {result['best_latency_ms']:.3f} ms")
print(f"Workers used: {result['num_workers']}")# Distributed gradient synchronization
gradients = [{"layer1": 0.1}, {"layer1": 0.3}, {"layer1": 0.2}, {"layer1": 0.4}]
reduced = engine.all_reduce_gradients(gradients, reduce_op="mean")
# reduced["layer1"] = 0.25python examples/cluster/deep_ray_integration_demo.py- Quick Start - Get started in 5 minutes
- API Reference - Complete API documentation
- Backend Guide - Backend usage and configuration
- Architecture Design - System design
| Platform | Guide | Description |
|---|---|---|
| Huawei Ascend NPU | Installation Guide | Complete setup, build, and test instructions |
| Huawei Ascend NPU | Quick Start | Quick API usage examples |
| MetaX MACA GPU | Quick Start | MetaX GPU integration |
- Contributing - Contribution guidelines
```bash
# MPS backend (Apple Silicon)
python benchmark/baselines/pytorch/gated_mlp.py -b 8 --backend mps
# CUDA backend (NVIDIA GPU)
python benchmark/baselines/pytorch/gated_mlp.py -b 8 --backend cuda
# CPU backend
python benchmark/baselines/pytorch/gated_mlp.py -b 8 --backend cpu
# Ascend backend (Huawei NPU) - requires CANN + torch_npu
python benchmark/baselines/pytorch/gated_mlp.py -b 8 --backend ascend
# MACA backend (MetaX GPU) - requires MACA SDK
python benchmark/baselines/pytorch/gated_mlp.py -b 8 --backend maca
```

```python
import yirage as yr
# Method 1: Direct specification
mpk = yr.PersistentKernel(backend="mps", ...)
# Method 2: With fallback
mpk = yr.PersistentKernel(
backend="cuda",
fallback_backends=["mps", "cpu"], # Auto fallback
...
)
# Method 3: Query and select
backends = yr.get_available_backends()
best_backend = backends[0]  # Use first available
```

We welcome contributions! Please see CONTRIBUTING.md for guidelines.
- Implement `BackendInterface`
- Create `{Backend}KernelConfig`
- Implement `{Backend}Optimizer`
- Create `{Backend}SearchStrategy` (optional)
- Update CMake configuration
See Ascend Implementation Guide for a complete example.
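As a rough orientation for that checklist, a hypothetical Python-side skeleton for a new backend (illustrative names only; the real BackendInterface and factory registration live in the C++/Cython sources, with the Ascend port as the reference implementation):

```python
from dataclasses import dataclass

@dataclass
class MyDevKernelConfig:            # step 2: {Backend}KernelConfig
    block_dim: tuple = (128, 1, 1)
    tile_m: int = 64
    tile_n: int = 64
    shared_memory_bytes: int = 48 * 1024

class MyDevOptimizer:               # step 3: {Backend}Optimizer
    @staticmethod
    def optimize_matmul(m, n, k, config):
        # A real optimizer would query device capabilities (units, memory, SIMD width)
        config.tile_m = min(128, m)
        config.tile_n = min(128, n)
        return config

class MyDevSearchStrategy:          # step 4 (optional): {Backend}SearchStrategy
    def candidate_configs(self, m, n, k):
        # Enumerate tile sizes for the hardware-aware search to evaluate
        for tile in (32, 64, 128):
            yield MyDevKernelConfig(tile_m=tile, tile_n=tile)
```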
YiRage is licensed under the Apache License 2.0.
Copyright:
- YiRage Multi-Backend Extensions: Copyright 2025 Chen Xingqiang
- Original Mirage: Copyright 2023-2024 Carnegie Mellon University
See LICENSE, NOTICE, and ATTRIBUTION for details.
```bibtex
@software{yirage2025,
title={YiRage: Yield Revolutionary AGile Engine for Multi-Backend LLM Inference},
author={Chen, Xingqiang},
year={2025},
note={A derivative work based on Mirage},
url={https://github.com/chenxingqiang/YiRage}
}
@inproceedings{wu2024mirage,
title={Mirage: A Multi-Level Superoptimizer for Tensor Programs},
author={Mengdi Wu and Xinhao Cheng and Shengyu Liu and others},
booktitle={OSDI 2025},
year={2025}
}
```

YiRage builds upon the excellent work of the Mirage team at Carnegie Mellon University.
- Issues: GitHub Issues
- Author: Chen Xingqiang
- Email: joy6677@outlook.com
YiRage - Yielding Maximum Performance Across All Hardware
Copyright 2025 Chen Xingqiang | Based on Mirage (CMU) | Apache License 2.0