Description
Required prerequisites
- I have searched the Issue Tracker and confirmed that this hasn't already been reported. (Comment there if it has.)
Motivation
Unlike GPGPU, the minimal unit for high-performance computation on A4E is a fixed-size 2D tile rather than a scalar. This fundamentally originates from A4E's SIMD and tile-based ALU design.
Specifically, in addition to conventional elementwise ops, we have ops that target 2D tiles. For example, given a 16x32 bf16 tile, we have native support for reduction along either the first or the second dimension. We also have native support for broadcasting a 1x16 or 16x1 tile to a 16x32 tile. This characteristic fits the prevalent tile-based programming model well.
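For intuition, here is a minimal NumPy sketch of these tile-level semantics (NumPy is used purely for illustration and is not the proposed interface):

import numpy as np

tile = np.random.rand(16, 32).astype(np.float32)  # stands in for a 16x32 bf16 tile

# Native reduction along either dimension of the tile
row_sum = tile.sum(axis=1)   # shape (16,), reduce over the second dimension
col_sum = tile.sum(axis=0)   # shape (32,), reduce over the first dimension

# Native broadcast of a 16x1 tile back to a 16x32 tile
scaled = tile * row_sum.reshape(16, 1)  # (16, 1) broadcast to (16, 32)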
To fully unlock this tile-native computation, we should reshape how developers perceive a variable defined in shared memory. It is no longer a set of scalars to be processed individually; instead, it is a set of tiles for tensor-unit processing, vector-unit processing, and inter-core and inter-chip data movement.
Solution
To achieve this goal, we need two front-end designs: tile-based shared memory allocation and tile-based iteration.
Tile-based Shared Memory Allocation
def alloc_shared_tiles(
    shape: tuple[Unpack[_Shapes]],
    dtype: _DType,
    tile_size: tuple[Unpack[_Shapes]] | None = None,
    scope="shared.dyn",
) -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
    pass

Usage
A_shared = T.alloc_shared_tiles((block_M, block_N), "float32", tile_size=(32, 32))
A_shared = T.alloc_shared_tiles((block_M, block_N), "float32")

Compared with T.alloc_shared, it takes an additional tile_size argument, specifying the tile size in shared memory. Developers can leave this field empty and let the compiler infer the tile size based on the target info. In A4E, we simply take the 32x32 option for 2D buffers and 32x1 for 1D buffers, which should be optimal in most scenarios.
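A minimal sketch of this default rule (default_tile_size is an illustrative helper, not part of the proposed API):

def default_tile_size(shape):
    # A4E default: 32x32 tiles for 2D (and higher-rank) buffers, 32x1 for 1D buffers.
    return (32, 32) if len(shape) >= 2 else (32, 1)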
This frontend function can be directly lowered to T.alloc_buffer, whose shape is normalized based on the inferred tile size. The resulting shape forms a ZZ layout. To be more concrete, if a tensor has two or more dimensions, perform tiling on the last two dimensions, e.g., [128, 128] -> [4, 4, 32, 32]. If a tensor has only one dimension, perform the transformation directly on it, e.g., [128] -> [4, 32, 1].
With these layout and shape transformations, developers can still loop over the original dimensions, but obtain tiles of size (32, 32) or (32, 1).
To preserve the original shape information, we may need to leverage T.annotate_layout.
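A minimal sketch of this shape normalization (normalize_shape and ceildiv are illustrative helpers, not the actual lowering code):

def ceildiv(a, b):
    return -(-a // b)

def normalize_shape(shape, tile_size):
    # Tile the last two dimensions (or the single dimension of a 1D buffer)
    # into a ZZ layout: outer tile counts followed by the tile extents.
    if len(shape) == 1:
        return (ceildiv(shape[0], tile_size[0]), tile_size[0], tile_size[1])
    m, n = shape[-2], shape[-1]
    return (*shape[:-2],
            ceildiv(m, tile_size[0]), ceildiv(n, tile_size[1]),
            tile_size[0], tile_size[1])

# normalize_shape((128, 128), (32, 32)) -> (4, 4, 32, 32)
# normalize_shape((128,), (32, 1))      -> (4, 32, 1)
# Conceptually, T.alloc_shared_tiles(shape, dtype) then lowers to roughly
# T.alloc_buffer(normalize_shape(shape, tile_size), dtype, scope="shared.dyn").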
Tile-based Iterations
As the tile size is determined by the compiler, we should provide front-end support for correctly iterating over the shared buffer.
for i, j in T.foreach_tiles(A_shared, parallel=True):
    pass

for i in T.foreach_tiles(B_shared, parallel=False):
    pass

T.foreach_tiles should be lowered to T.Parallel or T.Serial based on the parallel field.
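For illustration, a minimal sketch of how the frontend could expand T.foreach_tiles (the foreach_tiles body and the tiled_shape attribute are assumptions, not the actual implementation):

def foreach_tiles(buffer, parallel=True):
    # buffer.tiled_shape is assumed to hold the per-dimension tile counts,
    # e.g. (T.ceildiv(block_M, 32), T.ceildiv(block_N, 32)) for a 2D buffer.
    loop = T.Parallel if parallel else T.Serial
    return loop(*buffer.tiled_shape)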
The iteration extents are derived from the tiled buffer shape, i.e., T.ceildiv(block_M, 32) and T.ceildiv(block_N, 32):
for i, j in T.Parallel(T.ceildiv(block_M, 32), T.ceildiv(block_N, 32)):
    pass

Example usage
def rms_norm_splitk(M, N, blk_m, blk_k):
    dtype = "float"
    nrows, ncols = driver.get_num_rows(), driver.get_num_cols()

    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
    ):
        # Sharding
        T.annotate_mesh_tensor({
            "A": {"sharding": {"x": 0, "y": 1},
                  "tiling": [blk_m, blk_k]},
            "B": {"sharding": {"x": 0, "y": 1},
                  "tiling": [blk_m, blk_k]},
        }, nrows=nrows, ncols=ncols)
        Sharded_M, Sharded_N = T.get_sharded_shape(A)

        with T.Kernel(nrows, ncols) as (rid, cid):
            # If the layout is not specified, we should let the compiler search for it
            A_shared = T.alloc_shared_tiles((blk_m, blk_k), dtype)
            A_accum_shared = T.alloc_shared_tiles((blk_m, blk_k), dtype)
            A_powsum_local_shared = T.alloc_shared_tiles((blk_m,), dtype)
            A_powsum_shared = T.alloc_shared_tiles((blk_m, ncols), dtype)

            for bx in T.Persistent([T.ceildiv(Sharded_M, blk_m)]):
                T.fill(A_accum_shared, 0.0)
                for bk in T.Parallel(T.ceildiv(Sharded_N, blk_k)):
                    T.mesh_tensor_copy(A, A_shared, coord=[bx, bk])
                    # Power
                    for i, j in T.foreach_tiles(A_shared, parallel=True):
                        A_accum_shared[i, j] += A_shared[i, j] * A_shared[i, j]
                # Reduce sum within each blk_m
                T.reduce_sum(A_accum_shared, A_powsum_local_shared, dim=1)
                # All-gather powsum
                T.comm.col_allgather(A_powsum_local_shared, A_powsum_shared)
                # Reduce
                T.reduce_sum(A_powsum_shared, A_powsum_local_shared, dim=1)
                # Rsqrt and add eps
                for i in T.foreach_tiles(A_powsum_local_shared, parallel=True):
                    A_powsum_local_shared[i] = T.rsqrt(A_powsum_local_shared[i] / N) + 1e-12
                # Normalize
                for bk in T.Parallel(T.ceildiv(Sharded_N, blk_k)):
                    T.mesh_tensor_copy(A, A_shared, coord=[bx, bk])
                    for i, j in T.foreach_tiles(A_shared, parallel=True):
                        A_shared[i, j] = A_shared[i, j] * A_powsum_local_shared[i]
                    T.mesh_tensor_copy(A_shared, B, coord=[bx, bk])

    return main