[Feature Request] [Language] Tile-based Shared Memory #7

@SUNMMIO-jlou

Description

Required prerequisites

  • I have searched the Issue Tracker and confirmed that this hasn't already been reported. (If it has, please comment there instead.)

Motivation

Unlike on GPGPUs, the minimal unit of high-performance computation here is a fixed-size 2D tile rather than a scalar. This fundamentally originates from A4E's SIMD and tile-based ALU design.

Specifically, in addition to conventional elementwise ops, we have ops that target 2D tiles. For example, given a 16x32 bf16 tile, we have native support for reduction along either the first or the second dimension. We also have native support for broadcasting a 1x16 or 16x1 tile to a 16x32 tile. This characteristic fits well with the prevalent tile-based programming model.

To fully unlock this tile-native computation, we should reshape how developers perceive a variable defined in shared memory. It is no longer a set of scalars to be processed one by one; instead, it is a set of tiles for tensor-unit processing, vector-unit processing, and inter-core and inter-chip data movement.

Solution

To achieve this goal, we need two front-end designs: tile-based shared memory allocation and tile-based iteration.

Tile-based Shared Memory Allocation

def alloc_shared_tiles(
    shape: tuple[Unpack[_Shapes]],
    dtype: _DType,
    tile_size: tuple[Unpack[_Shapes]] | None = None,
    scope="shared.dyn",
) -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
    pass

Usage

A_shared = T.alloc_shared_tiles((block_M, block_N), "float32", tile_size=(32, 32))
A_shared = T.alloc_shared_tiles((block_M, block_N), "float32")

Compared with T.alloc_shared, it takes an additional tile_size argument, which specifies the tile size in shared memory. Developers can leave this field empty and let the compiler infer the tile size from the target info. On A4E, we simply take the 32x32 option for 2D buffers and 32x1 for 1D buffers, which should be optimal in most scenarios.
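
As an illustration of the inference rule above, a minimal sketch (the helper name infer_tile_size is hypothetical, not part of the proposed API):

def infer_tile_size(shape):
    # Assumed default rule for A4E: 32x32 tiles for buffers with two or more
    # dimensions, 32x1 tiles for 1D buffers. Other targets could plug in
    # their own rule based on target info.
    if len(shape) >= 2:
        return (32, 32)
    return (32, 1)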

This frontend function can be directly lowered to T.alloc_buffer, whose shape is normalized based on the inferred tile size. The resulting shape follows a ZZ layout. Concretely, if a tensor has two or more dimensions, tiling is performed on the last two dimensions, e.g., [128, 128] -> [4, 4, 32, 32]. If a tensor has only one dimension, the transformation is performed on it directly, e.g., [128] -> [4, 32, 1].
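
A rough sketch of this shape normalization, assuming the dimensions divide evenly into tiles (normalize_shape is a hypothetical helper, not part of the proposal):

def normalize_shape(shape, tile_size):
    # Two or more dimensions: tile the last two, e.g. [128, 128] -> [4, 4, 32, 32].
    if len(shape) >= 2:
        *lead, m, n = shape
        tm, tn = tile_size
        return (*lead, m // tm, n // tn, tm, tn)
    # One dimension: e.g. [128] with tile size (32, 1) -> [4, 32, 1].
    (m,) = shape
    tm, tn = tile_size
    return (m // tm, tm, tn)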

With this layout and shape transformation, developers can still loop over the original dimensions, but each iteration now yields a tile of size (32, 32) or (32, 1).

To preserve the original shape information, we may need to leverage T.annotate_layout.
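
To make the ZZ indexing concrete, an element at original coordinates (i, j) of a 2D buffer would land at the following tiled coordinates (pure illustration of the mapping, not a proposed API):

def zz_index(i, j, tile_size=(32, 32)):
    # Map an original (i, j) coordinate to its
    # (tile_row, tile_col, intra_row, intra_col) coordinate in the ZZ layout.
    tm, tn = tile_size
    return (i // tm, j // tn, i % tm, j % tn)

T.annotate_layout could then carry this mapping so that later passes can still relate the normalized buffer back to its original logical shape.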

Tile-based Iterations

Since the tile size is determined by the compiler, we should provide a frontend method to iterate over the shared buffer correctly.

for i, j in T.foreach_tiles(A_shared, parallel=True):
    pass

for i in T.foreach_tiles(B_shared, parallel=False):
    pass

T.foreach_tiles should be lowered to T.Parallel or T.Serial based on the parallel field.
The iteration extents are computed from the tiled buffer shape, i.e., T.ceildiv(block_M, 32) and T.ceildiv(block_N, 32).

for i, j in T.Parallel(T.ceildiv(block_M, 32), T.ceildiv(block_N, 32)):
    pass
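
A hedged sketch of how the frontend helper could perform this lowering, assuming the SharedBuffer returned by alloc_shared_tiles exposes its original shape and inferred tile size (the attribute names orig_shape and tile_size are assumptions for illustration):

def foreach_tiles(buffer, parallel=True):
    # One iteration extent per tiled dimension: ceildiv(original_extent, tile_extent).
    extents = [
        T.ceildiv(dim, tile)
        for dim, tile in zip(buffer.orig_shape, buffer.tile_size)
    ]
    # Dispatch to a parallel or serial loop based on the flag.
    loop = T.Parallel if parallel else T.Serial
    return loop(*extents)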

Example usage

def rms_norm_splitk(M, N, blk_m, blk_k):
    dtype = "float"
    nrows, ncols = driver.get_num_rows(), driver.get_num_cols()

    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype)
    ):
        # Sharding
        T.annotate_mesh_tensor({
            "A": {"sharding": {"x": 0, "y": 1},
                  "tiling": [blk_m, blk_k]},
            "B": {"sharding": {"x": 0, "y": 1},
                  "tiling": [blk_m, blk_k]},
        }, nrows=nrows, ncols=ncols)
        Sharded_M, Sharded_N = T.get_sharded_shape(A)

        with T.Kernel(nrows, ncols) as (rid, cid):
            # If the layout is not specified, we should let the compiler search for it
            A_shared = T.alloc_shared_tiles((blk_m, blk_k), dtype)
            A_accum_shared = T.alloc_shared_tiles((blk_m, blk_k), dtype)
            A_powsum_local_shared = T.alloc_shared_tiles((blk_m,), dtype)
            A_powsum_shared = T.alloc_shared_tiles((blk_m, ncols), dtype)
            for bx in T.Persistent([T.ceildiv(Sharded_M, blk_m)]):
                T.fill(A_accum_shared, 0.0)
                for bk in T.Parallel(T.ceildiv(Sharded_N, blk_k)):
                    T.mesh_tensor_copy(A, A_shared, coord=[bx, bk])
                    # Power
                    for i, j in T.foreach_tiles(A_shared, parallel=True):
                        A_accum_shared[i, j] += A_shared[i, j] * A_shared[i, j]

                # Reduce sum within each blk_m
                T.reduce_sum(A_accum_shared, A_powsum_local_shared, dim=1)
                # All gather powsum
                T.comm.col_allgather(A_powsum_local_shared, A_powsum_shared)
                # Reduce
                T.reduce_sum(A_powsum_shared, A_powsum_local_shared, dim=1)
                # Add eps and rsqrt
                for i in T.foreach_tiles(A_powsum_local_shared, parallel=True):
                    A_powsum_local_shared[i] = T.rsqrt(A_powsum_local_shared[i] / N + 1e-12)
                # Normalize
                for bk in T.Parallel(T.ceildiv(Sharded_N, blk_k)):
                    T.mesh_tensor_copy(A, A_shared, coord=[bx, bk])
                    for i, j in T.foreach_tiles(A_shared, parallel=True):
                        A_shared[i, j] = A_shared[i, j] * A_powsum_local_shared[i]
                    T.mesh_tensor_copy(A_shared, B, coord=[bx, bk])
