Skip to content

Commit

Permalink
[LLVMGPU] Use LLVMGPUDistribute for small input scatters (#19670)
Browse files Browse the repository at this point in the history
If the scattered slice is small then we will end up only distributing
the batch dimensions to workgroups. This causes bufferization to fail in
LLVMGPUTileAndFuse because the workgroup level `extract_slice` will fold
away.

Fixes #19639
  • Loading branch information
qedawkins authored Jan 10, 2025
1 parent f7a2157 commit 9f93691
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,19 @@ LogicalResult setScatterLoweringConfig(IREE::GPU::TargetAttr target,
}
}

int64_t numBatch = scatter.getBatchRank();
// Currently bufferization will fail if the only dimension distributed to
// workgroups is the batch dims because the workgroup level slice will fold
// away and cause a mismatch.
// TODO(qedawkins): Support this case.
if (llvm::all_of_zip(llvm::drop_begin(workgroupTileSizes, numBatch),
llvm::drop_begin(loopBounds, numBatch),
[](int64_t tileSize, int64_t bound) {
return tileSize == bound || tileSize == 0;
})) {
return failure();
}

// Attach the MMA schedule as an attribute to the entry point export function
// for later access in the pipeline.
MLIRContext *context = scatter.getContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,10 @@ func.func @only_scattered_dim(%arg0: tensor<48xf32>,
}

// CHECK-LABEL: func.func @only_scattered_dim
// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUDistribute workgroup_size = [128, 1, 1] subgroup_size = 64

// CHECK: linalg_ext.scatter {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: thread = [1]
// CHECK-SAME: workgroup = [48]
// CHECK: linalg_ext.scatter {{.*}}lowering_config = #iree_codegen.lowering_config
// CHECK-SAME: tile_sizes = {{\[}}[128]]

// -----

Expand Down

0 comments on commit 9f93691

Please sign in to comment.