107 changes: 107 additions & 0 deletions wave_lang/kernel/wave/schedule_reordering.py
@@ -78,6 +78,7 @@ class SchedReorderStrategy(Enum):
    TWO_PP_CLUSTER = 0x220
    ASYNC_TWO_PP_CLUSTER = 0x2201
    MXFP4_PP_CLUSTER = 0x101
    FOUR_WAVE_INTERWEAVE = 0x120


def is_pingpong_strategy(strategy):
@@ -103,6 +104,7 @@ class CompatibleBlockSize:
twoPPConfig = CompatibleBlockSize(128, 128, 64, 16, False, MMA)
asyncTwoPPConfig = CompatibleBlockSize(128, 128, 64, 16, True, MMA)
MXFP4PPConfig = CompatibleBlockSize(256, 128, 256, 4, False, ScaledMMA)
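# Compatibility config for the FOUR_WAVE_INTERWEAVE strategy; the fields appear
# to follow the (mTile, nTile, kTile, mma_bitwidth, use_global_to_shared,
# mma_type) order checked by is_compatible_strategy, i.e. 64x64x32 tiles with
# 16-bit MMA and global-to-shared loads enabled.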
fourWaveConfig = CompatibleBlockSize(64, 64, 32, 16, True, MMA)


class InsertionMode(Enum):
@@ -546,6 +548,16 @@ def select_reorder_strategy(
    hardware_constraint,
):
    flat_wave_count = math.prod(hardware_constraint.waves_per_block)
    if flat_wave_count == 4 and is_compatible_strategy(
        mTile,
        nTile,
        kTile,
        mma_bitwidth,
        use_global_to_shared,
        mma_type,
        fourWaveConfig,
    ):
        return SchedReorderStrategy.FOUR_WAVE_INTERWEAVE
    if flat_wave_count != 8:
        return SchedReorderStrategy.NONE
    if is_compatible_strategy(
@@ -824,6 +836,93 @@ def transform_MXFP4_PP_clusters(
    return clusters


def transform_four_wave_clusters(
    mma_nodes,
    local_load_lhs,
    local_load_rhs,
    global_to_shared_lhs,
    global_to_shared_rhs,
):
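    """Build the 4-wave interweave schedule.

    The MMAs and their local loads are split into two slices and laid out as
    four clusters separated by scheduling/workgroup barriers:
      1. local loads of slice 0 followed by the global-to-shared copies,
      2. MMAs of slice 0,
      3. local loads of slice 1,
      4. MMAs of slice 1.
    Memory-counter waits are placed before the workgroup barriers to account
    for outstanding global copies. Returns the cluster list consumed by
    reorder_graph.
    """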
    num_slices = 2
    sliced_mma_nodes, sliced_local_load_lhs, sliced_local_load_rhs = slice_mma(
        mma_nodes, local_load_lhs, local_load_rhs, num_slice=num_slices
    )
    # Check that we have valid slice sizes for the local_loads and mmas.
    assert len(sliced_mma_nodes) == len(sliced_local_load_rhs)
    assert len(sliced_mma_nodes) == len(sliced_local_load_lhs)
    assert len(sliced_mma_nodes) == num_slices

    context_location = mma_nodes and mma_nodes[0].location

    clusters = []
    tmp_graph = fx.Graph()
    # 1st cluster: interleaved local and global reads.
    clusters.append(sliced_local_load_lhs[0])
    clusters.append(sliced_local_load_rhs[0])
    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
    barrier_op.location = context_location
    clusters.append(insert_op_after(barrier_op, sliced_local_load_rhs[0]))

    clusters.append(global_to_shared_lhs)
    clusters.append(global_to_shared_rhs)
    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
    barrier_op.location = context_location
    clusters.append(insert_op_after(barrier_op, global_to_shared_rhs))

    barrier_op = WorkgroupBarrier().add_to_graph(tmp_graph)
    barrier_op.location = context_location
    clusters.append(insert_op_after(barrier_op, clusters[-1].op))
    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
    barrier_op.location = context_location
    clusters.append(insert_op_after(barrier_op, clusters[-1].op))

    # 2nd cluster: mma_slice[0].
    clusters.append(sliced_mma_nodes[0])
    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
    barrier_op.location = context_location
    clusters.append(insert_op_after(barrier_op, sliced_mma_nodes[0]))

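    # Keep at most `independent_global_count` loads in flight here, i.e. only
    # the global-to-shared copies issued above (assuming MemoryCounterWait
    # follows vmcnt-style semantics), before the next workgroup barrier.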
    independent_global_count = len(global_to_shared_lhs + global_to_shared_rhs)
    barrier_op = MemoryCounterWait(load=independent_global_count).add_to_graph(
        tmp_graph
    )
    barrier_op.location = context_location
    clusters.append(insert_op_after(barrier_op, clusters[-1].op))

    barrier_op = WorkgroupBarrier().add_to_graph(tmp_graph)
    barrier_op.location = context_location
    clusters.append(insert_op_after(barrier_op, clusters[-1].op))
    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
    barrier_op.location = context_location
    clusters.append(insert_op_after(barrier_op, clusters[-1].op))

    # 3rd cluster: local loads of the 2nd slice.
    clusters.append(sliced_local_load_lhs[1])
    clusters.append(sliced_local_load_rhs[1])
    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
    barrier_op.location = context_location
    clusters.append(insert_op_after(barrier_op, sliced_local_load_rhs[1]))

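    # Wait for the load counter to drain to zero before the barrier guarding
    # the 2nd MMA slice (assuming MemoryCounterWait(load=N) waits until at most
    # N loads remain outstanding).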
    barrier_op = MemoryCounterWait(load=0).add_to_graph(tmp_graph)
    barrier_op.location = context_location
    clusters.append(insert_op_after(barrier_op, clusters[-1].op))

    barrier_op = WorkgroupBarrier().add_to_graph(tmp_graph)
    barrier_op.location = context_location
    clusters.append(insert_op_after(barrier_op, clusters[-1].op))
    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
    barrier_op.location = context_location
    clusters.append(insert_op_after(barrier_op, clusters[-1].op))

    # 4th cluster: mma_slice[1].
    clusters.append(sliced_mma_nodes[1])
    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
    barrier_op.location = context_location
    clusters.append(insert_op_after(barrier_op, sliced_mma_nodes[1]))

    return clusters


##############################################################
# Helper fn to classify/detect ops.
##############################################################
@@ -1037,6 +1136,14 @@ def schedule_reordering(
            local_write_rhs_scale,
        )
        clusters = flatten_list(clusters)
    elif reorder_strategy == SchedReorderStrategy.FOUR_WAVE_INTERWEAVE:
        clusters = transform_four_wave_clusters(
            mma_nodes,
            local_load_lhs,
            local_load_rhs,
            global_to_shared_lhs,
            global_to_shared_rhs,
        )
    else:
        raise ValueError("Unhandled SchedReorderStrategy case.")
    reordered_graph = reorder_graph(graph, clusters)