diff --git a/include/tritonblas/matmul.py b/include/tritonblas/matmul.py index 43bc779..43dd8e2 100755 --- a/include/tritonblas/matmul.py +++ b/include/tritonblas/matmul.py @@ -1,3 +1,22 @@ +"""! +@file matmul.py +@brief High-level matrix multiplication interface with persistent and Stream-K execution modes. + +This module provides the main user-facing API for matrix multiplication operations +using Triton kernels optimized for AMD GPU hardware. It supports both persistent +and Stream-K execution strategies with automatic heuristic-based optimization. + +Key features: +- Automatic tile size and grid optimization via MatmulHeuristicResult +- Support for persistent and Stream-K execution modes +- LRU caching of heuristic results for repeated problem sizes +- Pre-allocated global buffers for Stream-K synchronization +- Support for various data types (FP32, FP16, BF16, FP8, etc.) + +@author TritonBLAS Development Team +@date 2024 +""" + import torch import triton import random @@ -8,20 +27,33 @@ from .origami import MatmulHeuristicResult from typing import Dict, Tuple, Optional +#! Cache for tensor storage to avoid repeated allocations _tensor_cache = {} + +#! Current CUDA device index for hardware queries current_device_index = torch.cuda.current_device() + +#! CUDA device properties for the current device current_device = torch.cuda.get_device_properties(current_device_index) + +#! Maximum number of streaming multiprocessors on current device MAX_SMS = current_device.multi_processor_count -# TODO: 256x256 for fp16/bf16, need adjust for fp8/fp4 + +#! Maximum block size for pre-allocated buffers (256x256 for fp16/bf16) +#! TODO: Adjust for fp8/fp4 data types MAX_BLOCK_SIZE = 65536 -# Global pre-allocated buffers +#! Pre-allocated global synchronization locks for Stream-K execution +#! Used to coordinate partial tile accumulation across workgroups _global_locks = torch.empty(MAX_SMS, device="cuda", dtype=torch.uint8) + +#! Pre-allocated global partial results buffer for Stream-K execution +#! Stores intermediate accumulation results during multi-workgroup tiles _global_P = torch.empty(MAX_SMS, MAX_BLOCK_SIZE, device="cuda", dtype=torch.float32) -# Function will behave like an LRU-Cache of heuristic results -# Saves several microseconds for previously seen problems by not rerunning the heuristic unnecessarily +#! LRU cache decorator for heuristic result caching +#! Saves several microseconds for previously seen problems by not rerunning optimization @functools.lru_cache(maxsize=1024) def _make_matmul_selector( M: int, @@ -31,11 +63,83 @@ def _make_matmul_selector( b_dtype: torch.dtype, c_dtype: torch.dtype, ): + """! + @brief Create and cache a matrix multiplication heuristic selector. + + Factory function that creates MatmulHeuristicResult instances with LRU caching + to avoid recomputing optimization parameters for repeated problem configurations. + This provides significant performance benefits for applications with recurring + matrix sizes and data types. 
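+
+    @par Example (illustrative)
+    The caching pattern is plain functools.lru_cache keyed on the problem
+    signature; the shapes and dtypes below are arbitrary examples, not a
+    recommended configuration:
+    @code{.py}
+    s1 = _make_matmul_selector(1024, 1024, 512,
+                               torch.float16, torch.float16, torch.float16)
+    s2 = _make_matmul_selector(1024, 1024, 512,
+                               torch.float16, torch.float16, torch.float16)
+    assert s1 is s2                        # cache hit: the same object is returned
+    _make_matmul_selector.cache_info()     # hits/misses for the current process
+    @endcode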
+ + @param M (int): Number of rows in matrix A and output matrix C + @param N (int): Number of columns in matrix B and output matrix C + @param K (int): Number of columns in matrix A and rows in matrix B + @param a_dtype (torch.dtype): Data type of input matrix A + @param b_dtype (torch.dtype): Data type of input matrix B + @param c_dtype (torch.dtype): Data type of output matrix C + + @return MatmulHeuristicResult: Optimized configuration selector for the problem + + @details + The LRU cache stores up to 1024 unique problem configurations, keyed by + the combination of matrix dimensions and data types. Cache hits avoid: + - Hardware detection and analysis + - Matrix instruction dimension inference + - Tile size optimization via Origami performance modeling + - Grid size computation + + Cache effectiveness depends on application patterns: + - Training workloads: High hit rates due to repeated layer sizes + - Inference workloads: High hit rates for batch processing + - Benchmarking: Perfect hit rates for repeated measurements + + @note The cache key includes all optimization-relevant parameters. + Changes to any parameter will result in cache misses and recomputation. + + @see MatmulHeuristicResult for optimization algorithm details + """ # Run Heuristic Results (Only if key has not been seen before) return MatmulHeuristicResult(M, N, K, a_dtype, b_dtype, c_dtype) def persistent_matmul_lt(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, selector): + """! + @brief Execute matrix multiplication using persistent (data-parallel) execution mode. + + Performs C = A @ B using a persistent execution strategy where each workgroup + processes exactly one tile of the output matrix. This mode provides predictable + performance and is optimal for problems with good load balance. + + @param a (torch.Tensor): Input matrix A with shape (M, K) + @param b (torch.Tensor): Input matrix B with shape (K, N) + @param c (torch.Tensor): Output matrix C with shape (M, N) - modified in-place + @param selector (MatmulHeuristicResult): Pre-computed optimization configuration + + @return torch.Tensor: Reference to the modified output tensor c + + @details + Persistent execution characteristics: + - One-to-one mapping between workgroups and output tiles + - Grid size equals total number of tiles (M/BLK_M * N/BLK_N) + - Optimal for well-balanced problems with uniform tile work + - Lower synchronization overhead compared to Stream-K + - Predictable execution time and resource usage + + Kernel configuration: + - 2 pipeline stages for memory latency hiding + - 8 warps per workgroup for occupancy optimization + - Matrix instruction size of 16x16 for AMD CDNA architectures + - K-dimension packing factor of 1 + + @pre Matrix A inner dimension must equal matrix B outer dimension (a.shape[1] == b.shape[0]) + @pre All tensors must be on the same CUDA device + @pre Selector must be configured for the same problem dimensions + + @throws AssertionError: If matrix dimensions are incompatible + + @see streamk_matmul_lt() for Stream-K execution alternative + @see MatmulHeuristicResult for optimization configuration details + """ assert a.shape[1] == b.shape[0], "Incompatible Dimensions" M, K = a.shape _, N = b.shape @@ -48,42 +152,42 @@ def persistent_matmul_lt(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, sele total_programs = total_tiles even_k = K % BLK_K == 0 - # TODO: Separate these configs. - # basica configs for most of compute bound sizes - # TODO: set these values analytically? 
- num_stages = 2 - num_warps = 8 - waves_per_eu = 0 - mfmaInstrSize = 16 - kpack = 1 + # Kernel execution parameters optimized for most compute-bound workloads + # TODO: Separate these configs for different problem characteristics + num_stages = 2 # Pipeline stages for memory/compute overlap + num_warps = 8 # Warps per workgroup for occupancy + waves_per_eu = 0 # Let hardware scheduler decide + mfmaInstrSize = 16 # Matrix instruction size (16x16) + kpack = 1 # K-dimension packing factor - # Run in Data-parallel mode. + # Configure for data-parallel execution grids = total_tiles - # TODO: Support other matmul algs. + # Execute persistent matrix multiplication kernel + # TODO: Support bias addition and other GEMM variants kk = persistent_matmul[(grids,)]( a, b, c, - None, # TODO: Enable bias. + None, # TODO: Enable bias tensor M, N, K, - a.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - 0, # TODO: Enable bias stride. - stride_ak=a.stride(1), - stride_bk=b.stride(0), + a.stride(0), # Stride for A rows + b.stride(1), # Stride for B columns + c.stride(0), # Stride for C rows + c.stride(1), # Stride for C columns + 0, # TODO: Enable bias stride + stride_ak=a.stride(1), # Stride for A columns (K dimension) + stride_bk=b.stride(0), # Stride for B rows (K dimension) BLOCK_SIZE_M=BLK_M, BLOCK_SIZE_N=BLK_N, BLOCK_SIZE_K=BLK_K, GROUP_SIZE_M=gsize_m, NUM_SMS=total_programs, - NUM_XCDS=8, - BIAS=False, - EVEN_K=even_k, + NUM_XCDS=8, # Number of shader arrays (hardware-specific) + BIAS=False, # TODO: Enable bias support + EVEN_K=even_k, # K dimension divisibility optimization num_stages=num_stages, num_warps=num_warps, waves_per_eu=waves_per_eu, @@ -97,6 +201,51 @@ def persistent_matmul_lt(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, sele def streamk_matmul_lt( a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, selector, sk_grid: Optional[int] = None ): + """! + @brief Execute matrix multiplication using Stream-K execution mode for load balancing. + + Performs C = A @ B using Stream-K execution strategy where workgroups can process + multiple tiles or fractions of tiles to achieve better load balancing. This mode + is particularly effective for problems with irregular tile boundaries or when + the number of tiles doesn't evenly divide the number of compute units. + + @param a (torch.Tensor): Input matrix A with shape (M, K) + @param b (torch.Tensor): Input matrix B with shape (K, N) + @param c (torch.Tensor): Output matrix C with shape (M, N) - modified in-place + @param selector (MatmulHeuristicResult): Pre-computed optimization configuration + @param sk_grid (Optional[int]): Override grid size for Stream-K execution. + If None, uses selector's optimized grid size. 
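+
+    @par Worked example (illustrative numbers)
+    Mirroring the tile accounting done in the function body: with M = N = 4096 and
+    BLK_M = BLK_N = 256 there are 16 x 16 = 256 output tiles. On a grid of 304
+    workgroups, 256 % 304 = 256, so every tile takes the Stream-K path; on a grid
+    of 200, 256 % 200 = 56 tiles are split cooperatively while the remaining 200
+    follow the classical one-tile-per-workgroup path.
+    @code{.py}
+    import math
+    total_tiles = math.ceil(M / BLK_M) * math.ceil(N / BLK_N)
+    total_tiles_streamk = total_tiles % total_programs_streamk  # tiles on the Stream-K path
+    @endcode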
+ + @return torch.Tensor: Reference to the modified output tensor c + + @details + Stream-K execution characteristics: + - Workgroups can process partial tiles for load balancing + - Grid size optimized via dynamic algorithms in compute_sk_grid() + - Requires synchronization buffers for partial result accumulation + - Better utilization when tiles don't evenly distribute across CUs + - Higher overhead but improved performance for irregular problems + + Synchronization mechanism: + - Uses pre-allocated global buffers when possible for performance + - Falls back to dynamic allocation for large grids/blocks + - Atomic operations coordinate partial tile accumulation + - Lock-based synchronization ensures correctness + + Buffer management: + - Reuses global buffers (_global_locks, _global_P) when size permits + - Allocates temporary buffers for oversized configurations + - Optimized buffer zeroing and initialization + + @pre Matrix A inner dimension must equal matrix B outer dimension (a.shape[1] == b.shape[0]) + @pre All tensors must be on the same CUDA device + @pre Selector must be configured for the same problem dimensions + + @throws AssertionError: If matrix dimensions are incompatible + + @see persistent_matmul_lt() for data-parallel execution alternative + @see MatmulHeuristicResult.compute_sk_grid() for grid optimization algorithm + """ assert a.shape[1] == b.shape[0], "Incompatible Dimensions" M, K = a.shape _, N = b.shape @@ -109,61 +258,65 @@ def streamk_matmul_lt( even_k = K % BLK_K == 0 ## - # Grid Size + # Grid Size Configuration ## total_programs_streamk = selector.get_grid() - if total_programs_streamk > 0: # Stream-K + if total_programs_streamk > 0: # Stream-K mode enabled total_tiles_streamk = total_tiles % total_programs_streamk - else: # all tiles are computed using classical blocking + else: # Fallback to classical blocking (persistent mode) total_tiles_streamk = 0 - num_stages = 2 - num_warps = 8 - waves_per_eu = 0 - mfmaInstrSize = 16 - kpack = 1 + # Kernel execution parameters + num_stages = 2 # Pipeline stages for memory/compute overlap + num_warps = 8 # Warps per workgroup for occupancy + waves_per_eu = 0 # Let hardware scheduler decide + mfmaInstrSize = 16 # Matrix instruction size (16x16) + kpack = 1 # K-dimension packing factor + # Override grid size if explicitly specified if sk_grid is not None: total_programs_streamk = sk_grid grids = total_programs_streamk block_size = BLK_M * BLK_N - # Use global buffers with optimized zeroing + # Efficient buffer management with pre-allocated globals when possible if grids <= MAX_SMS and block_size <= MAX_BLOCK_SIZE: - locks = _global_locks[:grids] - P = _global_P[:grids, :block_size] + locks = _global_locks[:grids] # Synchronization locks + P = _global_P[:grids, :block_size] # Partial results buffer else: + # Dynamic allocation for oversized configurations locks = torch.empty(grids, device="cuda", dtype=torch.uint8) P = torch.empty(grids, block_size, device="cuda", dtype=torch.float32) + # Execute Stream-K matrix multiplication kernel kk = streamk_matmul[(grids,)]( a, b, c, - None, # TODO: Enable bias. - P, - locks, + None, # TODO: Enable bias tensor + P, # Partial results accumulation buffer + locks, # Synchronization locks for coordination M, N, K, - a.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - 0, # TODO: Enable bias stride. 
- stride_ak=a.stride(1), - stride_bk=b.stride(0), + a.stride(0), # Stride for A rows + b.stride(1), # Stride for B columns + c.stride(0), # Stride for C rows + c.stride(1), # Stride for C columns + 0, # TODO: Enable bias stride + stride_ak=a.stride(1), # Stride for A columns (K dimension) + stride_bk=b.stride(0), # Stride for B rows (K dimension) BLOCK_SIZE_M=BLK_M, BLOCK_SIZE_N=BLK_N, BLOCK_SIZE_K=BLK_K, GROUP_SIZE_M=gsize_m, NUM_SMS=grids, - NUM_XCDS=8, - STREAMK_TILES=total_tiles_streamk, - BIAS=False, - EVEN_K=even_k, + NUM_XCDS=8, # Number of shader arrays (hardware-specific) + STREAMK_TILES=total_tiles_streamk, # Number of Stream-K tiles + BIAS=False, # TODO: Enable bias support + EVEN_K=even_k, # K dimension divisibility optimization num_stages=num_stages, num_warps=num_warps, waves_per_eu=waves_per_eu, @@ -177,6 +330,43 @@ def streamk_matmul_lt( def matmul_lt( a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, selector, enable_streamk=False ): + """! + @brief Low-level matrix multiplication interface with pre-computed selector. + + Performs C = A @ B using a pre-computed MatmulHeuristicResult selector to avoid + optimization overhead. Provides choice between persistent and Stream-K execution + modes for different performance characteristics. + + @param a (torch.Tensor): Input matrix A with shape (M, K) + @param b (torch.Tensor): Input matrix B with shape (K, N) + @param c (torch.Tensor): Output matrix C with shape (M, N) - modified in-place + @param selector (MatmulHeuristicResult): Pre-computed optimization configuration + @param enable_streamk (bool, optional): Enable Stream-K execution mode for load balancing. + Default is False (uses persistent mode). + + @return torch.Tensor: Reference to the modified output tensor c + + @details + This is a low-level interface that requires manual selector management but provides + maximum control over execution strategy. Useful for scenarios where: + - Multiple operations share the same matrix dimensions and data types + - Custom grid size or tile selection is required + - Benchmark or profiling requires consistent configurations + + Execution mode selection: + - Persistent mode (enable_streamk=False): Optimal for well-balanced problems + - Stream-K mode (enable_streamk=True): Better for irregular or small problems + + @pre Matrix A inner dimension must equal matrix B outer dimension (a.shape[1] == b.shape[0]) + @pre All tensors must be on the same CUDA device + @pre Selector must be configured for compatible problem dimensions + + @throws AssertionError: If matrix dimensions are incompatible + + @see matmul() for high-level interface with automatic selector creation + @see persistent_matmul_lt() for persistent execution details + @see streamk_matmul_lt() for Stream-K execution details + """ assert a.shape[1] == b.shape[0], "Incompatible Dimensions" if enable_streamk: @@ -192,11 +382,79 @@ def matmul( enable_streamk=False, sk_grid=None, ): + """! + @brief High-level matrix multiplication interface with automatic optimization. + + Performs C = A @ B with automatic heuristic-based optimization for tile sizes, + grid configuration, and execution strategy. This is the primary user-facing API + for matrix multiplication operations in TritonBLAS. + + @param a (torch.Tensor): Input matrix A with shape (M, K) + @param b (torch.Tensor): Input matrix B with shape (K, N) + @param c (torch.Tensor): Output matrix C with shape (M, N) - modified in-place + @param enable_streamk (bool, optional): Enable Stream-K execution mode for load balancing. 
+ Default is False (uses persistent mode). + @param sk_grid (Optional[int]): Override Stream-K grid size. Only used when + enable_streamk=True. If None, uses optimized grid size. + + @return torch.Tensor: Reference to the modified output tensor c + + @details + This function provides a complete automated matrix multiplication pipeline: + 1. Extract matrix dimensions and data types + 2. Create/retrieve cached MatmulHeuristicResult selector + 3. Dispatch to appropriate execution kernel (persistent or Stream-K) + 4. Return modified output tensor + + Automatic optimizations include: + - Hardware-aware tile size selection via Origami performance modeling + - Matrix instruction dimension inference based on data types + - Grid size optimization for load balancing + - LRU caching of optimization results for repeated problem sizes + + Execution modes: + - Persistent mode (default): One workgroup per output tile, optimal for balanced problems + - Stream-K mode: Dynamic load balancing with partial tile processing + + Performance considerations: + - First call for new (M,N,K,dtype) combination incurs optimization overhead (~microseconds) + - Subsequent calls with same parameters use cached results (near-zero overhead) + - Memory allocation is handled internally with pre-allocated buffers when possible + + @pre Matrix A inner dimension must equal matrix B outer dimension (a.shape[1] == b.shape[0]) + @pre All tensors must be on the same CUDA device and contiguous in memory + @pre Tensors must have compatible data types supported by the hardware + + @throws AssertionError: If matrix dimensions are incompatible + + @see matmul_lt() for low-level interface with manual selector management + @see MatmulHeuristicResult for optimization algorithm details + + @par Example Usage: + @code{.py} + import torch + import tritonblas + + # Create input matrices + A = torch.randn(1024, 512, device="cuda", dtype=torch.float16) + B = torch.randn(512, 1024, device="cuda", dtype=torch.float16) + C = torch.zeros(1024, 1024, device="cuda", dtype=torch.float16) + + # Perform matrix multiplication + result = tritonblas.matmul(A, B, C) + + # Enable Stream-K for better load balancing + result = tritonblas.matmul(A, B, C, enable_streamk=True) + @endcode + """ assert a.shape[1] == b.shape[0], "Incompatible Dimensions" M, K = a.shape _, N = b.shape + # Create or retrieve cached optimization selector selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, c.dtype) + + # Dispatch to appropriate execution mode if enable_streamk: return streamk_matmul_lt(a, b, c, selector, sk_grid=sk_grid) else: diff --git a/include/tritonblas/origami.py b/include/tritonblas/origami.py index f858af7..bbd4651 100755 --- a/include/tritonblas/origami.py +++ b/include/tritonblas/origami.py @@ -1,24 +1,69 @@ +"""! +@file origami.py +@brief Matrix multiplication heuristic optimization module using Origami hardware abstraction. + +This module provides the MatmulHeuristicResult class which implements intelligent +tile selection and grid computation for optimal matrix multiplication performance +on AMD GPU hardware. It leverages the Origami library for hardware-aware optimization. + +@author TritonBLAS Development Team +@date 2024 +""" + import torch import itertools from math import ceil import origami -# https://docs.pytorch.org/docs/stable/tensors.html +#! Mapping from PyTorch data types to their string representations. +#! Used for interfacing with the Origami library's datatype system. +#! 
@see https://docs.pytorch.org/docs/stable/tensors.html dtype_to_str = { - torch.float32: "f32", - torch.complex64: "c32", - torch.complex128: "c64", - torch.float64: "f64", - torch.float16: "f16", - torch.int32: "i32", - torch.bfloat16: "bf16", - torch.int8: "i8", - torch.float8_e5m2: "f8", - torch.float8_e4m3fn: "f8", + torch.float32: "f32", #!< 32-bit floating point + torch.complex64: "c32", #!< 32-bit complex (single precision) + torch.complex128: "c64", #!< 64-bit complex (double precision) + torch.float64: "f64", #!< 64-bit floating point (double precision) + torch.float16: "f16", #!< 16-bit floating point (half precision) + torch.int32: "i32", #!< 32-bit signed integer + torch.bfloat16: "bf16", #!< 16-bit brain floating point + torch.int8: "i8", #!< 8-bit signed integer + torch.float8_e5m2: "f8", #!< 8-bit floating point (5-bit exponent, 2-bit mantissa) + torch.float8_e4m3fn: "f8", #!< 8-bit floating point (4-bit exponent, 3-bit mantissa, finite/NaN) } class MatmulHeuristicResult: + """! + @brief Heuristic-based matrix multiplication configuration optimizer. + + This class analyzes matrix multiplication parameters and hardware characteristics + to determine optimal tile sizes, grid configurations, and execution strategies + for maximum performance on AMD GPU architectures. + + The heuristic considers: + - Matrix dimensions (M, N, K) + - Data types and their bit widths + - Hardware specifications (CU count, matrix instruction support) + - Memory hierarchy optimization + - Stream-K vs persistent execution modes + + @details + The optimization process involves: + 1. Hardware detection and matrix instruction dimension inference + 2. Valid tile size generation based on hardware constraints + 3. Performance modeling using the Origami library + 4. Grid size computation for optimal resource utilization + 5. Stream-K grid optimization for load balancing + + Supported AMD GPU architectures: + - gfx950: 256 CUs, supports FP32/FP16/BF16/FP8 + - gfx942 (MI300X): 304 CUs, supports FP32/FP16/BF16/FP8 + - gfx942 (MI300A): 228 CUs, supports FP32/FP16/BF16/FP8 + - gfx908 (MI200): 104 CUs, supports FP32/FP16/BF16 + + @note This class is typically instantiated automatically by the matmul functions + and cached for performance via LRU caching mechanisms. + """ def __init__( self, m, @@ -31,6 +76,43 @@ def __init__( mx_block_size=0, # Number of MX datatype elements that share a scale streamk=True, ): + """! + @brief Initialize the matrix multiplication heuristic optimizer. + + Analyzes the given matrix multiplication problem and hardware to determine + optimal execution parameters including tile sizes, grid configuration, + and execution strategy. + + @param m (int): Number of rows in matrix A and output matrix C + @param n (int): Number of columns in matrix B and output matrix C + @param k (int): Number of columns in matrix A and rows in matrix B (inner dimension) + @param a_dtype (torch.dtype): Data type of input matrix A + @param b_dtype (torch.dtype): Data type of input matrix B + @param c_dtype (torch.dtype): Data type of output matrix C + @param MI_dim (list[int], optional): Matrix instruction dimensions [M, N, K]. + If None, dimensions are inferred from hardware and data types. + @param mx_block_size (int, optional): Number of MX datatype elements sharing a scale factor. + Used for mixed-precision optimizations. Default is 0 (disabled). + @param streamk (bool, optional): Enable Stream-K execution mode for better load balancing. + Default is True. When False, uses persistent execution mode. 
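+
+        @par Example (illustrative shapes and dtypes)
+        A typical standalone use; the matmul() entry point normally constructs and
+        caches this object for you:
+        @code{.py}
+        selector = MatmulHeuristicResult(4096, 4096, 8192,
+                                         torch.bfloat16, torch.bfloat16, torch.bfloat16)
+        BLK_M, BLK_N, BLK_K, gsize_m = selector.get_config()  # tile sizes + GROUP_SIZE_M
+        grid = selector.get_grid()                             # number of workgroups to launch
+        @endcode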
+ + @details + Initialization process: + 1. Store matrix dimensions and data types + 2. Query hardware information via Origami + 3. Compute element sizes and infer matrix instruction dimensions + 4. Generate optimal tile configuration using performance modeling + 5. Compute grid size based on execution mode (Stream-K or persistent) + + @note The constructor performs significant computation including hardware + queries and performance modeling. Results should be cached when possible. + + @throws ValueError: If unsupported hardware architecture is detected or + if data types are incompatible with detected hardware. + + @see _infer_matrix_instruction_dimensions() for hardware compatibility details + @see compute_sk_grid() for Stream-K grid optimization algorithm + """ # Set Instance Variables self.m = m @@ -159,6 +241,34 @@ def _infer_matrix_instruction_dimensions(self, element_size_A, element_size_B): return MI_dim def _get_valid_tiles(self): + """! + @brief Generate all valid tile size combinations for the current hardware configuration. + + Creates a Cartesian product of possible tile dimensions (M, N, K) combined with + matrix instruction dimensions and kernel occupancy options. This provides the + search space for tile optimization. + + @return list[tuple]: List of tuples containing: + - Block size M (from block_mn_range) + - Block size N (from block_mn_range) + - Block size K (from block_k_range) + - Matrix instruction M dimension + - Matrix instruction N dimension + - Matrix instruction K dimension + - Kernel occupancy (workgroups per CU) + + @details + The tile dimensions are constrained by: + - Hardware matrix instruction capabilities (MI_dim) + - Memory hierarchy efficiency (block_mn_range, block_k_range) + - Compute unit occupancy requirements (kernel_occupancy) + + @note This method generates the full search space for tile optimization. + The actual selection is performed by _get_best_tile_size(). + + @see _get_best_tile_size() for tile selection algorithm + @see _infer_matrix_instruction_dimensions() for MI_dim computation + """ return list( itertools.product( self.block_mn_range, @@ -172,6 +282,32 @@ def _get_valid_tiles(self): ) def _get_gsize_m(self, BLK_M, BLK_N, BLK_K): + """! + @brief Determine optimal workgroup mapping (GROUP_SIZE_M) for given tile dimensions. + + Uses the Origami library to select the best workgroup mapping strategy that + minimizes memory bank conflicts and maximizes compute unit utilization + for the specified tile configuration. + + @param BLK_M (int): Tile size in the M dimension (rows of A, rows of C) + @param BLK_N (int): Tile size in the N dimension (cols of B, cols of C) + @param BLK_K (int): Tile size in the K dimension (cols of A, rows of B) + + @return int: Optimal GROUP_SIZE_M value for workgroup scheduling + + @details + The workgroup mapping affects: + - Memory access patterns and bank conflicts + - Load balancing across compute units + - Cache locality and reuse + - Overall throughput and efficiency + + The method evaluates multiple GROUP_SIZE_M candidates [1, 2, 4, 6, 8] + using hardware-aware performance modeling with L2 cache hit rate assumptions. + + @note This uses Origami's select_best_wgm function which performs detailed + hardware modeling including memory hierarchy analysis. + """ results = origami.select_best_wgm( self.m, # M self.n, # N @@ -193,6 +329,42 @@ def _get_gsize_m(self, BLK_M, BLK_N, BLK_K): return results[1] def _get_best_tile_size(self): + """! 
+ @brief Select optimal tile dimensions using hardware-aware performance modeling. + + Evaluates all valid tile configurations using the Origami library's performance + modeling to determine the tile sizes that maximize throughput for the given + matrix multiplication problem and hardware configuration. + + @return tuple[int, int, int]: Optimal tile dimensions (BLK_M, BLK_N, BLK_K) + + @details + The selection process: + 1. Generate all valid tile combinations via _get_valid_tiles() + 2. Use Origami's select_best_macro_tile_size() for performance modeling + 3. Apply hardware-specific heuristics for fine-tuning + 4. Return the tile configuration with highest predicted performance + + Performance modeling considers: + - Memory bandwidth utilization + - Compute unit occupancy and efficiency + - Cache hit rates and memory hierarchy + - Matrix instruction utilization + - Load balancing characteristics + + Hardware-specific adjustments: + - MI300X (304 CUs): Applies heuristics for 256x256 tiles to balance + performance vs occupancy based on empirical observations + + @note The modeling assumes: + - Transposed A matrix (transA=True) + - Non-transposed B matrix (transB=False) + - L2 cache hit rate of 80% (0.8) + - Default workgroup mapping of 6 + + @see _get_valid_tiles() for tile generation + @see origami.select_best_macro_tile_size() for performance modeling details + """ valid_tiles = self._get_valid_tiles() results = origami.select_best_macro_tile_size( self.m, # M @@ -216,8 +388,9 @@ def _get_best_tile_size(self): best_result = results[0] - # Heuristic weightin to different tiles - if self.hardware.N_CU == 304: + # Heuristic weighting for specific hardware configurations + if self.hardware.N_CU == 304: # MI300X + # For 256x256 tiles, consider alternative if performance is close if best_result[1] == 256 and best_result[2] == 256: if results[0][0] * 1.00 > results[1][0]: best_result = results[1] @@ -225,14 +398,65 @@ def _get_best_tile_size(self): return (best_result[1], best_result[2], best_result[3]) def _prepare_config(self): + """! + @brief Prepare complete configuration tuple with optimal tile sizes and workgroup mapping. + + Combines the results of tile size optimization and workgroup mapping selection + into a complete configuration tuple that can be used by the matrix multiplication + kernels. + + @return tuple[int, int, int, int]: Configuration tuple containing: + - BLK_M: Optimal tile size in M dimension + - BLK_N: Optimal tile size in N dimension + - BLK_K: Optimal tile size in K dimension + - gsize_m: Optimal workgroup mapping (GROUP_SIZE_M) + + @details + This method orchestrates the full optimization process: + 1. Determines optimal tile dimensions via _get_best_tile_size() + 2. Computes optimal workgroup mapping via _get_gsize_m() + 3. Returns complete configuration for kernel execution + + @see _get_best_tile_size() for tile optimization + @see _get_gsize_m() for workgroup mapping optimization + """ BLK_M, BLK_N, BLK_K = self._get_best_tile_size() gsize_m = self._get_gsize_m(BLK_M, BLK_N, BLK_K) return BLK_M, BLK_N, BLK_K, gsize_m def get_config(self): + """! + @brief Get the optimal configuration tuple for matrix multiplication execution. 
+ + @return tuple[int, int, int, int]: Configuration tuple (BLK_M, BLK_N, BLK_K, GROUP_SIZE_M) + + @details + Returns the pre-computed optimal configuration that includes: + - BLK_M: Tile size in M dimension (rows) + - BLK_N: Tile size in N dimension (columns) + - BLK_K: Tile size in K dimension (inner dimension) + - GROUP_SIZE_M: Workgroup mapping parameter + + This configuration is used by the Triton kernels for optimal performance. + """ return self.config def get_grid(self): + """! + @brief Get the optimal grid size for kernel execution. + + @return int: Grid size (number of workgroups) for optimal load balancing + + @details + Returns the pre-computed grid size that depends on the execution mode: + - Stream-K mode: Uses compute_sk_grid() for dynamic load balancing + - Persistent mode: Uses hardware CU count for static scheduling + + The grid size determines how work is distributed across compute units + and affects load balancing, memory access patterns, and overall performance. + + @see compute_sk_grid() for Stream-K grid computation algorithm + """ return self.grid def partial_tile_size(self, sk_grid: int) -> int:
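For reference, the low-level path documented in this diff looks roughly like the sketch
below. It is illustrative only: the import paths assume the package root maps to
include/tritonblas/ as laid out here, and a ROCm/CUDA-capable device must be available.

import torch
from tritonblas.matmul import matmul_lt
from tritonblas.origami import MatmulHeuristicResult

M, N, K = 4096, 4096, 4096
A = torch.randn(M, K, device="cuda", dtype=torch.float16)
B = torch.randn(K, N, device="cuda", dtype=torch.float16)
C = torch.empty(M, N, device="cuda", dtype=torch.float16)

# Build the heuristic once and reuse it across calls with the same shape/dtype,
# mirroring what matmul() does internally through its LRU cache.
selector = MatmulHeuristicResult(M, N, K, A.dtype, B.dtype, C.dtype)

matmul_lt(A, B, C, selector)                       # persistent (data-parallel) mode
matmul_lt(A, B, C, selector, enable_streamk=True)  # Stream-K load-balanced mode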