From c3517cb1ce8d40d13a333e993e9e04a7c76a7626 Mon Sep 17 00:00:00 2001 From: xiaoyao-NKU <18946090101@163.com> Date: Thu, 22 Jan 2026 10:51:25 +0800 Subject: [PATCH 1/5] Implement Inter-core Communication in TileOp way: broadcast, put, all_gather --- src/op/comm.cc | 585 ++++++++++++++++++++++++++++++++++ src/op/comm.h | 136 ++++++++ tilelang/language/__init__.py | 2 +- tilelang/language/comm.py | 412 ++++++++++++++++++++++++ 4 files changed, 1134 insertions(+), 1 deletion(-) create mode 100644 src/op/comm.cc create mode 100644 src/op/comm.h create mode 100644 tilelang/language/comm.py diff --git a/src/op/comm.cc b/src/op/comm.cc new file mode 100644 index 000000000..19b1edbc0 --- /dev/null +++ b/src/op/comm.cc @@ -0,0 +1,585 @@ +/*! + * \file tl/op/comm.cc + * \brief Implementation of Inter-core Communication Operators + */ + +#include "comm.h" + +#include +#include +#include + +#include "../target/utils.h" +#include "copy.h" +#include "reduce.h" +#include "utils.h" + +namespace tvm { +namespace tl { + +#define TIR_DEFINE_TL_BUILTIN(OpName) \ + const Op &OpName() { \ + static const Op &op = Op::Get("tl." #OpName); \ + return op; \ + } \ + TVM_REGISTER_OP("tl." 
#OpName) \ + .set_attr("TScriptPrinterName", #OpName) +TIR_DEFINE_TL_BUILTIN(comm_barrier) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(comm_fence) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(CoreId).set_num_inputs(1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(comm_current_core) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(comm_is_current_core) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(broadcast_) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); +// src_buffer, dst_buffer, size(IntImm), src_core(IntImm) +// direction(0: horizontal, 1: vertical), +// *mask(optional: IntImm list of core ids to exclude) + +using namespace tir; + +Broadcast::Broadcast(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + node->src_expr = args[0]; + node->dst_expr = args[1]; + Array rgs[2]; + Buffer bf[2]; + for (int i = 0; i < 2; i++) { + auto region = NormalizeToBufferRegion(args[i]); + + rgs[i] = region->region; + bf[i] = region->buffer; + } + std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); + std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); + node->size = Downcast(args[2]); + node->dst_offset = Downcast(args[3]); + node->src_core = Downcast(args[4]); + node->direction = Downcast(args[5])->value; + data_ = std::move(node); +} + +TileOperator BroadcastNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return Broadcast(op); +} + +LayoutMap BroadcastNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + Array args; + args.push_back(src_expr); + args.push_back(dst_expr); + Copy copy_op = Copy(args); + LayoutMap out_layout = copy_op->InferLayout(T, level); + return 
out_layout; +} + +// int get_target_mesh_nrows(Target target) { +// auto mattr = target->GetAttr>("mattr").value(); +// int x = 0; +// for (size_t i = 0; i < mattr.size(); i++) { +// std::string m = mattr[i]; +// if (m.find("device_mesh_nrow_") != std::string::npos) { +// std::string s = m.substr(m.find_last_of('_') + 1);; +// try { +// x = std::stoi(s); +// } catch (const std::invalid_argument& e) { +// x = -1; +// } catch (const std::out_of_range& e) { +// x = -1; +// } +// } +// } +// ICHECK(x != 0) << "Device mesh row number not found."; +// ICHECK(x > 0) << "Invalid device mesh row number: "; +// return x; +// } + +int get_target_mesh(Target target, int axis) { + auto mattr = target->GetAttr>("mattr").value(); + int x = 0; + std::string axis_str; + if (axis == 0) { + axis_str = "device_mesh_ncol_"; + } else if (axis == 1) { + axis_str = "device_mesh_nrow_"; + } else { + LOG(FATAL) << "Invalid axis " << axis << " for getting mesh dimension."; + } + for (size_t i = 0; i < mattr.size(); i++) { + std::string m = mattr[i]; + if (m.find(axis_str) != std::string::npos) { + std::string s = m.substr(m.find_last_of('_') + 1);; + try { + x = std::stoi(s); + } catch (const std::invalid_argument& e) { + x = -1; + } catch (const std::out_of_range& e) { + x = -1; + } + } + } + ICHECK(x != 0) << axis_str << " not found."; + ICHECK(x > 0) << "Invalid " << axis_str; + return x; +} + + +Stmt BroadcastNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + Target target = T.target; + ICHECK(TargetIsSunmmio(target)) << "Broadcast only supports SUNMMIO targets."; + int mesh_x = get_target_mesh(target, 0); + int mesh_y = get_target_mesh(target, 1); + + // check for valid core id + ICHECK(src_core->value >= 0 and src_core->value < mesh_x * mesh_y) + << "Source core id " << src_core->value << " out of range [0, " + << mesh_x * mesh_y << ")"; + + // check for src and dst buffer sizes + PrimExpr src_elements = 1; + for (size_t i = 0; i < src_range.size(); i++) { + 
src_elements *= src_range[i]->extent; + } + src_elements = analyzer->Simplify(src_elements); + PrimExpr dst_elements = 1; + for (size_t i = 0; i < dst_range.size(); i++) { + dst_elements *= dst_range[i]->extent; + } + dst_elements = analyzer->Simplify(dst_elements); + ICHECK(Downcast(src_elements)->value <= + Downcast(dst_elements)->value) + << "Source buffer size larger than destination buffer size: " + << src_elements << " vs " << dst_elements; + ICHECK(size->value <= Downcast(src_elements)->value) + << "Broadcast size larger than data size: " << size->value << " vs " + << Downcast(src_elements)->value; + + // check for size and dst_offset + PrimExpr broadcast_elements; + if (size->value < 0) { + broadcast_elements = src_elements; + } else { + broadcast_elements = size; + } + ICHECK((Downcast(broadcast_elements)->value) <= + Downcast(src_elements)->value) + << "Broadcast size Larger than source buffer size: " + << (Downcast(broadcast_elements)->value) << " vs " + << Downcast(src_elements)->value; + ICHECK((Downcast(broadcast_elements)->value + dst_offset->value) <= + Downcast(dst_elements)->value) + << "Broadcast size + dst_offset larger than destination buffer size: " + << (Downcast(broadcast_elements)->value + dst_offset->value) + << " vs " << Downcast(dst_elements)->value; + + // check for valid direction + if (direction != 0 and direction != 1 and direction != 2) { + LOG(FATAL) << "Invalid broadcast direction " << direction + << ", must be 0 (horizontal) or 1 (vertical) or 2 (all)."; + } + + // all checks passed, generate the call + PrimExpr src_addr = src.access_ptr(1, DataType::Handle(), 1, 0, src_elements); + PrimExpr dst_addr = + dst.access_ptr(2, DataType::Handle(), 1, + Downcast(dst_offset->value), src_elements); + int src_core_y = src_core->value % mesh_y; + + if (direction == 0 or direction == 1) { + // 1D broadcast + Array args; + args.push_back(src_addr); + args.push_back(dst_addr); + args.push_back(Downcast(broadcast_elements)); + 
args.push_back(src_core); + args.push_back(direction); + Stmt broadcast = Evaluate(Call(DataType::Handle(), broadcast_(), args)); + return broadcast; + } else { + // 2D broadcast + Array seq; + // vertical broadcast + Array args; + args.push_back(src_addr); + args.push_back(dst_addr); + args.push_back(Downcast(broadcast_elements)); + args.push_back(src_core); + args.push_back(1); // direction: vertical + Stmt broadcast = Evaluate(Call(DataType::Handle(), broadcast_(), args)); + seq.push_back(broadcast); + // herizontal broadcast + for (int i = 0; i < mesh_x; i++) { + Array args; + args.push_back(dst.access_ptr(1, DataType::Handle(), 1, 0, dst_elements)); + args.push_back(dst.access_ptr(2, DataType::Handle(), 1, 0, dst_elements)); + args.push_back(Downcast(broadcast_elements)); + args.push_back(int(i * mesh_y) + src_core_y); + args.push_back(0); // direction: horeizontal + args.push_back( + IntImm(DataType::Int(32), src_core_y)); // mask: current core only + Stmt broadcast = Evaluate(Call(DataType::Handle(), broadcast_(), args)); + seq.push_back(broadcast); + } + return SeqStmt::Flatten(seq); + } +} + +TIR_REGISTER_TL_TILE_OP(Broadcast, comm_broadcast) + .set_num_inputs(6) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +Put::Put(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + node->src_expr = args[0]; + node->dst_expr = args[1]; + Array rgs[2]; + Buffer bf[2]; + for (int i = 0; i < 2; i++) { + auto region = NormalizeToBufferRegion(args[i]); + rgs[i] = region->region; + bf[i] = region->buffer; + } + std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); + std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); + node->size = Downcast(args[2]); + node->src_core = Downcast(args[3]); + node->dst_core = Downcast(args[4]); + data_ = std::move(node); +} + +TileOperator PutNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return Put(op); +} + +LayoutMap PutNode::InferLayout(const LayoutInferArgs &T, + 
InferLevel level) const { + Array args; + args.push_back(src_expr); + args.push_back(dst_expr); + Copy copy_op = Copy(args); + LayoutMap out_layout = copy_op->InferLayout(T, level); + return out_layout; +} + +Stmt PutNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + Target target = T.target; + ICHECK(TargetIsSunmmio(target)) << "Put only supports SUNMMIO targets."; + int mesh_x = get_target_mesh(target, 0); + int mesh_y = get_target_mesh(target, 1); + + // check for valid core id + ICHECK(src_core->value >= 0 and src_core->value < mesh_x * mesh_y) + << "Source core id " << src_core->value << " out of range [0, " + << mesh_x * mesh_y << ")"; + ICHECK(dst_core->value >= 0 and dst_core->value < mesh_x * mesh_y) + << "Destination core id " << dst_core->value << " out of range [0, " + << mesh_x * mesh_y << ")"; + + // check for src and dst buffer sizes + PrimExpr src_elements = 1; + for (size_t i = 0; i < src_range.size(); i++) { + src_elements *= src_range[i]->extent; + } + src_elements = analyzer->Simplify(src_elements); + PrimExpr dst_elements = 1; + for (size_t i = 0; i < dst_range.size(); i++) { + dst_elements *= dst_range[i]->extent; + } + dst_elements = analyzer->Simplify(dst_elements); + ICHECK(Downcast(src_elements)->value <= + Downcast(dst_elements)->value) + << "Source buffer size larger than destination buffer size: " + << src_elements << " vs " << dst_elements; + ICHECK(size->value <= Downcast(src_elements)->value) + << "Broadcast size larger than data size: " << size->value << " vs " + << Downcast(src_elements)->value; + + // check for size + PrimExpr broadcast_elements; + if (size->value < 0) { + broadcast_elements = src_elements; + } else { + broadcast_elements = size; + } + ICHECK((Downcast(broadcast_elements)->value) <= + Downcast(src_elements)->value) + << "Broadcast size Larger than source buffer size: " + << (Downcast(broadcast_elements)->value) << " vs " + << Downcast(src_elements)->value; + 
ICHECK((Downcast(broadcast_elements)->value) <= + Downcast(dst_elements)->value) + << "Broadcast size larger than destination buffer size: " + << (Downcast(broadcast_elements)->value) << " vs " + << Downcast(dst_elements)->value; + + // all checks passed, generate the call + PrimExpr src_addr = src.access_ptr(1, DataType::Handle(), 1, 0, src_elements); + PrimExpr dst_addr = dst.access_ptr(2, DataType::Handle(), 1, 0, dst_elements); + int src_core_x = src_core->value / mesh_y; + int src_core_y = src_core->value % mesh_y; + int dst_core_x = dst_core->value / mesh_y; + int dst_core_y = dst_core->value % mesh_y; + + if (src_core_x == dst_core_x) { + // 1D put via horizontal communication + Array args; + args.push_back(src_addr); + args.push_back(dst_addr); + args.push_back(Downcast(broadcast_elements)); + args.push_back(src_core); + args.push_back(0); // direction: horizontal + for (int j = 0; j < mesh_y; j++) { + if (j != dst_core_y) { + args.push_back( + IntImm(DataType::Int(32), j)); // mask: all cores except dst_core_y + } + } + Stmt put = Evaluate(Call(DataType::Handle(), broadcast_(), args)); + return put; + } else if (src_core_y == dst_core_y) { + // 1D put via vertical communication + Array args; + args.push_back(src_addr); + args.push_back(dst_addr); + args.push_back(Downcast(broadcast_elements)); + args.push_back(src_core); + args.push_back(1); // direction: vertical + for (int i = 0; i < mesh_x; i++) { + if (i != dst_core_x) { + args.push_back( + IntImm(DataType::Int(32), i)); // mask: all cores except dst_core_x + } + } + Stmt put = Evaluate(Call(DataType::Handle(), broadcast_(), args)); + return put; + } else { + Array seq; + // vertical transfer from src core to intermediate core + int intermediate_core_id = src_core_x * mesh_y + dst_core_y; + Array args1; + args1.push_back(src_addr); + args1.push_back(dst_addr); + args1.push_back(Downcast(broadcast_elements)); + args1.push_back(src_core); + args1.push_back(1); // direction: vertical + for (int i = 0; i < 
mesh_x; i++) { + if (i != dst_core_x) { + args1.push_back( + IntImm(DataType::Int(32), i)); // mask: all cores except dst_core_x + } + } + Stmt put1 = Evaluate(Call(DataType::Handle(), broadcast_(), args1)); + seq.push_back(put1); + // horizontal transfer from intermediate core to dst core + Array args2; + args2.push_back(dst.access_ptr(1, DataType::Handle(), 1, 0, src_elements)); + args2.push_back(dst.access_ptr(2, DataType::Handle(), 1, 0, dst_elements)); + args2.push_back(Downcast(broadcast_elements)); + args2.push_back(IntImm(DataType::Int(32), intermediate_core_id)); + args2.push_back(0); // direction: horizontal + for (int j = 0; j < mesh_y; j++) { + if (j != dst_core_y) { + args2.push_back( + IntImm(DataType::Int(32), j)); // mask: all cores except dst_core_y + } + } + Stmt put2 = Evaluate(Call(DataType::Handle(), broadcast_(), args2)); + seq.push_back(put2); + return SeqStmt::Flatten(seq); + } +} + +TIR_REGISTER_TL_TILE_OP(Put, comm_put) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +Allgather::Allgather(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + node->send = args[0]; + node->recv = args[1]; + node->direction = Downcast(args[2])->value; + node->size = Downcast(args[3]); + data_ = std::move(node); +} + +TileOperator AllgatherNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return Allgather(op); +} + +Layout AllgatherNode::ComputeLinearLayout(const Buffer &shared_tensor) const { + Array input_size = shared_tensor->shape; + Array forward_vars; + for (size_t i = 0; i < input_size.size(); i++) { + forward_vars.push_back(InputPlaceholder(i)); + } + // [i, j] -> [i // 256, j // 256, i % 256, j % 256] + Array forward_index; + for (size_t i = 0; i < input_size.size(); i++) { + forward_index.push_back(FloorDiv(forward_vars[i], 256)); + } + for (size_t i = 0; i < input_size.size(); i++) { + forward_index.push_back(FloorMod(forward_vars[i], 256)); + } + return Layout(input_size, 
forward_index); +} + +LayoutMap AllgatherNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + Buffer recv_buffer = NormalizeToBufferRegion(recv)->buffer; + Layout linear_layout = ComputeLinearLayout(recv_buffer); + return Map({{recv_buffer, linear_layout}}); +} + +Stmt AllgatherNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + Target target = T.target; + ICHECK(TargetIsSunmmio(target)) << "Allgather only supports SUNMMIO targets."; + int mesh_x = get_target_mesh(target, 0); + int mesh_y = get_target_mesh(target, 1); + + Array send_range, recv_range; + auto send_region = NormalizeToBufferRegion(send); + auto recv_region = NormalizeToBufferRegion(recv); + send_range = send_region->region; + recv_range = recv_region->region; + + int recv_num = 1; + if (direction == 0) { // horizontal + recv_num = mesh_y; + } else if (direction == 1) { // vertical + recv_num = mesh_x; + } else if (direction == 2) { // all + recv_num = mesh_x * mesh_y; + } else { + // invalid direction + ICHECK(false) << "Invalid direction value for allgather: " << direction; + } + + PrimExpr send_elements = 1; + for (size_t i = 0; i < send_range.size(); i++) { + send_elements *= send_range[i]->extent; + } + send_elements = analyzer->Simplify(send_elements); + PrimExpr recv_elements = 1; + for (size_t i = 0; i < recv_range.size(); i++) { + recv_elements *= recv_range[i]->extent; + } + recv_elements = analyzer->Simplify(recv_elements); + // check for buffer sizes + ICHECK(Downcast(send_elements)->value * recv_num <= + Downcast(recv_elements)->value) + << "Receive buffer size not enough for allgather: required " + << (Downcast(send_elements)->value * recv_num) << ", but got " + << Downcast(recv_elements)->value; + + // all checks passed, generate the calls + Array bcast_stmts; + + if (direction == 0) { // horizontal + for (int i = 0; i < mesh_x; i++) { + for (size_t j = 0; j < mesh_y; j++) { + Array args; + args.push_back(send); + args.push_back(recv); + 
args.push_back(size); + args.push_back(IntImm(DataType::Int(32), j) * send_elements); // offset + args.push_back(IntImm(DataType::Int(32), i * mesh_y + j)); // src_core + args.push_back(0); // direction: horizontal + Broadcast bcast = Broadcast(args); + Stmt bcast_stmt = bcast->Lower(T, analyzer); + bcast_stmts.push_back(bcast_stmt); + } + } + } else if (direction == 1) { // vertical + for (int j = 0; j < mesh_y; j++) { + for (size_t i = 0; i < mesh_x; i++) { + Array args; + args.push_back(send); + args.push_back(recv); + args.push_back(size); + args.push_back(IntImm(DataType::Int(32), i) * send_elements); // offset + args.push_back(IntImm(DataType::Int(32), i * mesh_y + j)); // src_core + args.push_back(1); // direction: vertical + Broadcast bcast = Broadcast(args); + Stmt bcast_stmt = bcast->Lower(T, analyzer); + bcast_stmts.push_back(bcast_stmt); + } + } + } else if (direction == 2) { // all + // first do horizontal allgather + for (int i = 0; i < mesh_x; i++) { + for (size_t j = 0; j < mesh_y; j++) { + Array args; + args.push_back(send); + args.push_back(recv); + args.push_back(size); + args.push_back(IntImm(DataType::Int(32), i * mesh_y + j) * + send_elements); // offset + args.push_back(IntImm(DataType::Int(32), i * mesh_y + j)); // src_core + args.push_back(0); // direction: horizontal + Broadcast bcast = Broadcast(args); + Stmt bcast_stmt = bcast->Lower(T, analyzer); + bcast_stmts.push_back(bcast_stmt); + } + } + // then do vertical allgather + Buffer recv_buffer = recv_region->buffer; + int allgather_size = (size->value < 0) + ? 
Downcast(send_elements)->value * mesh_y + : size->value * mesh_y; + + for (int j = 0; j < mesh_y; j++) { + for (size_t i = 0; i < mesh_x; i++) { + Array args; + args.push_back(recv_buffer.access_ptr( + 1, DataType::Handle(), 1, + IntImm(DataType::Int(32), i * mesh_y) * send_elements, + IntImm(DataType::Int(32), mesh_y) * send_elements)); + args.push_back(recv_buffer.access_ptr( + 2, DataType::Handle(), 1, + IntImm(DataType::Int(32), i * mesh_y) * send_elements, + IntImm(DataType::Int(32), mesh_y) * send_elements)); + args.push_back(IntImm(DataType::Int(32), allgather_size)); // size + args.push_back(IntImm(DataType::Int(32), i * mesh_y + j)); // src_core + args.push_back(1); // direction: vertical + args.push_back(IntImm(DataType::Int(32), i)); // mask: current row only + Stmt bcast_stmt = + Evaluate(Call(DataType::Handle(), broadcast_(), args)); + bcast_stmts.push_back(bcast_stmt); + } + } + } + return SeqStmt::Flatten(bcast_stmts); +} + +TIR_REGISTER_TL_TILE_OP(Allgather, comm_allgather) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK() { + PutNode::RegisterReflection(); + BroadcastNode::RegisterReflection(); + AllgatherNode::RegisterReflection(); +} + +} // namespace tl +} // namespace tvm diff --git a/src/op/comm.h b/src/op/comm.h new file mode 100644 index 000000000..e0b4ffe43 --- /dev/null +++ b/src/op/comm.h @@ -0,0 +1,136 @@ +/*! 
+ * \file tl/op/comm.h + * \brief Implementation of Inter-core Communication Operators + */ + +#ifndef TVM_TL_OP_COMM_H_ +#define TVM_TL_OP_COMM_H_ + +#include "operator.h" + +namespace tvm { +namespace tl { + +TVM_DLL const Op &CoreId(); +TVM_DLL const Op &comm_current_core(); +TVM_DLL const Op &comm_is_current_core(); +TVM_DLL const Op &comm_barrier(); +TVM_DLL const Op &comm_fence(); +TVM_DLL const Op &broadcast_(); + +using namespace tir; + +class BroadcastNode : public TileOperatorNode { +public: + Buffer src, dst; + Array src_range, dst_range; + PrimExpr src_expr, dst_expr; + IntImm size; + IntImm dst_offset; + IntImm src_core; + int direction; + // Array group; + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_broadcast", BroadcastNode, + TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src", &BroadcastNode::src) + .def_ro("dst", &BroadcastNode::dst) + .def_ro("src_range", &BroadcastNode::src_range) + .def_ro("dst_range", &BroadcastNode::dst_range) + .def_ro("src_core", &BroadcastNode::src_core) + .def_ro("direction", &BroadcastNode::direction) + .def_ro("size", &BroadcastNode::size) + .def_ro("dst_offset", &BroadcastNode::dst_offset); + // .def_ro("group", &BroadcastNode::group); + } + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + TileOperator Clone() const; +}; + +class Broadcast : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Broadcast, TileOperator, + BroadcastNode); + TVM_DLL Broadcast(Array args); + static const Op &Get(); +}; + +class PutNode : public TileOperatorNode { +public: + Buffer src, dst; + Array src_range, dst_range; + PrimExpr src_expr, dst_expr; + IntImm src_core, dst_core; + IntImm size; + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_put", PutNode, TileOperatorNode); + + static void RegisterReflection() { + namespace 
refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src", &PutNode::src) + .def_ro("dst", &PutNode::dst) + .def_ro("src_range", &PutNode::src_range) + .def_ro("dst_range", &PutNode::dst_range) + .def_ro("src_core", &PutNode::src_core) + .def_ro("dst_core", &PutNode::dst_core) + .def_ro("size", &PutNode::size); + } + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + TileOperator Clone() const; +}; + +class Put : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Put, TileOperator, PutNode); + TVM_DLL Put(Array args); + static const Op &Get(); +}; + +class AllgatherNode : public TileOperatorNode { +public: + PrimExpr send, recv; + int direction; + IntImm size; + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_allgather", AllgatherNode, + TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("send", &AllgatherNode::send) + .def_ro("recv", &AllgatherNode::recv) + .def_ro("direction", &AllgatherNode::direction) + .def_ro("size", &AllgatherNode::size); + } + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + Layout ComputeLinearLayout(const Buffer &shared_tensor) const; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + TileOperator Clone() const; +}; + +class Allgather : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Allgather, TileOperator, + AllgatherNode); + TVM_DLL Allgather(Array args); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_COMM_H_ diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index a4a48310a..ea77c0d5c 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -106,7 +106,7 @@ from .annotations import ( # noqa: F401 use_swizzle, annotate_layout, 
annotate_safe_value, annotate_l2_hit_ratio, ) - +from . import comm # noqa: F401 def import_source(source: str | None = None): # source is the source code to be imported diff --git a/tilelang/language/comm.py b/tilelang/language/comm.py new file mode 100644 index 000000000..c42f26c2a --- /dev/null +++ b/tilelang/language/comm.py @@ -0,0 +1,412 @@ +"""Communication intrinsics wrappers for TileLang. + +This module provides small helper functions that prepare arguments and +emit TIR intrinsics for inter-core communication on a target mesh. +""" + +from __future__ import annotations + +from typing import Literal +from collections.abc import Iterable + +from tvm import tir +import tilelang.language as T +from tilelang.utils.language import to_buffer_region + + +DIRECTION_MAP = {"horizontal": 0, "h": 0, "vertical": 1, "v": 1, "all": 2, "a": 2} + + +def get_target_mesh_shape(target: str = "auto") -> dict[str, int]: + """Get the shape of the target mesh as a dictionary with 'x' and 'y' keys. + Args: + target: The target mesh type. Supported values are + 'sunmmio-a4e', 'sunmmio-a4e-lite', and 'auto'. If 'auto' is specified, + the function defaults to 'sunmmio-a4e'. + Returns: + A dictionary with integer keys 'x' and 'y' representing + the 2D mesh size in each dimension. + Raises: + ValueError: If an unknown target is specified. + """ + if target == "auto": + target = "sunmmio-a4e" + + if target == "sunmmio-a4e": + return {"x": 4, "y": 4} + elif target == "sunmmio-a4e-lite": + return {"x": 2, "y": 4} + else: + raise ValueError(f"Unknown target: {target}") + + +def core_tuple_to_id(core_id: tuple[int, int]) -> int: + """Convert 2D (row, col) coordinates on the mesh into a linear core id. + + Parameters + ---------- + core_id : tuple[int, int] + A tuple specifying the (row, col) coordinates of the core on the mesh. + + Returns + ------- + int + The linear core id corresponding to the provided coordinates. 
+ + Notes + ----- + The conversion uses the current target mesh shape obtained via + get_target_mesh_shape("auto"). + """ + mesh_shape = get_target_mesh_shape("auto") + row, col = core_id + assert ( + 0 <= row < mesh_shape["x"] + ), f"Row {row} out of bounds for mesh shape {mesh_shape}." + assert ( + 0 <= col < mesh_shape["y"] + ), f"Col {col} out of bounds for mesh shape {mesh_shape}." + core_id_value = row * mesh_shape["y"] + col + return core_id_value + + +def core_id_to_tuple(core_id: tir.Call) -> tuple[int, int]: + """Convert a linear core id into 2D (row, col) coordinates on the mesh. + + Parameters + ---------- + core_id : tir.Call + A linear core identifier (or a TIR expression that yields one). + + Returns + ------- + tuple[int, int] + The (row, col) coordinates corresponding to the linear core id. + + Notes + ----- + The conversion uses the current target mesh shape obtained via + get_target_mesh_shape("auto"). + """ + mesh_shape = get_target_mesh_shape("auto") + core_id_value = core_id + row = core_id_value // mesh_shape["y"] + col = core_id_value % mesh_shape["y"] + return (row, col) + + +def CoreId(core_id: int | tuple[int, int]): + """Convert a core identifier to a linear core ID for the target mesh. + + Parameters + ---------- + core_id : int or tuple[int, int] + Either a linear core id (int) or a 2-tuple (row, col) specifying the + core coordinates on the target mesh. + + Returns + ------- + int + The linear core id mapped into [0, mesh_x * mesh_y). + + Raises + ------ + AssertionError, ValueError + If the provided coordinates are out of bounds or the type is invalid. + """ + mesh_shape = get_target_mesh_shape("auto") + if isinstance(core_id, tuple): + row, col = core_id + assert ( + 0 <= row < mesh_shape["x"] + ), f"Row {row} out of bounds for mesh shape {mesh_shape}" + assert ( + 0 <= col < mesh_shape["y"] + ), f"Col {col} out of bounds for mesh shape {mesh_shape}" + # Convert 2D coordinates into a linear core id. 
+        # Row stride is the column count (mesh_shape["y"]), matching core_tuple_to_id.
+        core_id_value = row * mesh_shape["y"] + col
+    elif isinstance(core_id, int):
+        core_id_value = core_id
+        assert (
+            0 <= core_id_value < mesh_shape["x"] * mesh_shape["y"]
+        ), f"Core ID {core_id_value} out of bounds for mesh shape {mesh_shape}"
+    else:
+        raise ValueError("core_id must be either a tuple[int, int] or an int.")
+    return tir.call_intrin("handle", tir.op.Op.get("tl.CoreId"), core_id_value)
+
+
+def current_core():
+    """Get the current core's identifier.
+
+    Returns
+    -------
+    tir.Call
+        The TIR intrinsic call handle for `tl.comm_current_core`.
+
+    Examples
+    --------
+    >>> current_core()
+    """
+    return tir.call_intrin("handle", tir.op.Op.get("tl.comm_current_core"))
+
+
+def broadcast(
+    src: T.Buffer,
+    dst: T.Buffer,
+    src_core: tuple[int, int],
+    direction: Literal["horizontal", "h", "vertical", "v", "all", "a"] = "all",
+    size: int = -1,
+):
+    """Broadcast data from a source buffer on a specific source core to a destination buffer
+    on all cores in the specified direction by emitting the TIR intrinsic tl.tileop.comm_broadcast.
+    Parameters
+    ----------
+    src : T.Buffer
+        Source buffer containing data to broadcast.
+    dst : T.Buffer
+        Destination buffer to receive the broadcasted data.
+    src_core : tuple[int, int]
+        (row, col) coordinates of the source core on the target mesh.
+    direction : Literal["horizontal", "h", "vertical", "v", "all", "a"]
+        Direction of broadcast: "horizontal" (or "h") for row-wise, "vertical" (or "v") for column-wise,
+        and "all" (or "a") for all cores.
+    size : int
+        Number of elements to broadcast. If -1, the entire source buffer is used.
+    Returns
+    -------
+    tir.Call
+        The TIR intrinsic call handle for `tl.tileop.comm_broadcast`.
+    Examples
+    --------
+    >>> broadcast(A, B, (1, 2), direction="horizontal")
+    """
+    assert (
+        src.dtype == dst.dtype
+    ), f"Source and destination buffer dtypes must match for broadcast. Got {src.dtype} vs {dst.dtype}."
+ if len(src.shape) != len(dst.shape): + raise ValueError( + "Source and destination buffer must have the same number of dimensions for broadcast." + ) + for i in range(len(src.shape)): + assert (src.shape[i] == dst.shape[i] or src.shape[i] == 1 or dst.shape[i] == 1 + ), f"Source buffer shape and destination buffer shape must match for broadcast. Got {src.shape} vs {dst.shape}." + + mesh_shape = get_target_mesh_shape("auto") + assert ( + isinstance(src_core, tuple) and len(src_core) == 2 + ), "src_core must be a tuple of (row, col)." + assert ( + 0 <= src_core[0] < mesh_shape["x"] + ), f"src_core row {src_core[0]} out of bounds for mesh shape {mesh_shape}." + assert ( + 0 <= src_core[1] < mesh_shape["y"] + ), f"src_core col {src_core[1]} out of bounds for mesh shape {mesh_shape}." + + src_elements = 1 + for dim in src.shape: + src_elements *= dim + assert isinstance(size, int) and size >= -1, "size must be an integer >= -1." + assert size <= src_elements, f"size {size} exceeds source buffer size {src_elements}." + + assert direction.lower() in DIRECTION_MAP, f"Invalid direction string: {direction}" + + src_region = to_buffer_region(src) + dst_region = to_buffer_region(dst) + src_core_id = core_tuple_to_id(src_core) + dst_offset = 0 # Always 0 for now + + args = ( + src_region, + dst_region, + size, + dst_offset, + src_core_id, + DIRECTION_MAP[direction.lower()], + ) + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.comm_broadcast"), *args) + + +def put( + src: T.Buffer, + dst: T.Buffer, + src_core: tuple[int, int], + dst_core: tuple[int, int], + size: int = -1, +): + """Put data from a source buffer on a specific source core to a destination buffer on a specific destination core + by emitting the TIR intrinsic tl.tileop.comm_put. + Parameters + ---------- + src : T.Buffer + Source buffer containing data to put. + dst : T.Buffer + Destination buffer to receive the data. 
+ src_core : tuple[int, int] + (row, col) coordinates of the source core on the target mesh. + dst_core : tuple[int, int] + (row, col) coordinates of the destination core on the target mesh. + size : int + Number of elements to put. If -1, the entire source buffer is used. + Returns + ------- + tir.Call + The TIR intrinsic call handle for `tl.tileop.comm_put`. + Examples + -------- + >>> put(A, B, (1, 2), (2, 3)) + """ + assert ( + src.dtype == dst.dtype + ), f"Source and destination buffer dtypes must match for broadcast. Got {src.dtype} vs {dst.dtype}." + if len(src.shape) != len(dst.shape): + raise ValueError( + "Source and destination buffer must have the same number of dimensions for broadcast." + ) + for i in range(len(src.shape)): + assert ( + src.shape[i] == dst.shape[i] or src.shape[i] == 1 or dst.shape[i] == 1 + ), f"Source buffer shape and destination buffer shape must match for broadcast. Got {src.shape} vs {dst.shape}." + + mesh_shape = get_target_mesh_shape("auto") + assert ( + isinstance(src_core, tuple) and len(src_core) == 2 + ), "src_core must be a tuple of (row, col)." + assert ( + 0 <= src_core[0] < mesh_shape["x"] + ), f"src_core row {src_core[0]} out of bounds for mesh shape {mesh_shape}." + assert ( + 0 <= src_core[1] < mesh_shape["y"] + ), f"src_core col {src_core[1]} out of bounds for mesh shape {mesh_shape}." + assert ( + isinstance(dst_core, tuple) and len(dst_core) == 2 + ), "dst_core must be a tuple of (row, col)." + assert ( + 0 <= dst_core[0] < mesh_shape["x"] + ), f"dst_core row {dst_core[0]} out of bounds for mesh shape {mesh_shape}." + assert ( + 0 <= dst_core[1] < mesh_shape["y"] + ), f"dst_core col {dst_core[1]} out of bounds for mesh shape {mesh_shape}." + src_elements = 1 + for dim in src.shape: + src_elements *= dim + assert isinstance(size, int) and size >= -1, "size must be an integer >= -1." + assert ( + size <= src_elements + ), f"size {size} exceeds source buffer size {src_elements}." 
+
+    src_region = to_buffer_region(src)
+    dst_region = to_buffer_region(dst)
+    src_core_id = core_tuple_to_id(src_core)
+    dst_core_id = core_tuple_to_id(dst_core)
+    args = (src_region, dst_region, size, src_core_id, dst_core_id)
+    return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.comm_put"), *args)
+
+
+def all_gather(
+    send_buffer: T.Buffer,
+    recv_buffer: T.Buffer,
+    direction: Literal["horizontal", "h", "vertical", "v", "all", "a"] = "all",
+    size: int = -1,
+):
+    """Perform an all-gather operation from a send buffer to a receive buffer
+    by emitting the TIR intrinsic tl.tileop.comm_allgather.
+    Parameters
+    ----------
+    send_buffer : T.Buffer
+        Buffer containing data to send.
+    recv_buffer : T.Buffer
+        Buffer to receive gathered data.
+    direction : Literal["horizontal", "h", "vertical", "v", "all", "a"]
+        Direction of all-gather: "horizontal" (or "h") for row-wise, "vertical" (or "v") for column-wise,
+        and "all" (or "a") for all cores.
+    size : int
+        Number of elements to send from each core. If -1, the entire send buffer is used.
+    Returns
+    -------
+    tir.Call
+        The TIR intrinsic call handle for `tl.tileop.comm_allgather`.
+    Examples
+    --------
+    >>> all_gather(A_local, C_local, direction="horizontal")
+    """
+    assert direction.lower() in DIRECTION_MAP, f"Invalid direction string: {direction}"
+
+    assert (send_buffer.dtype == recv_buffer.dtype
+            ), f"Source and destination buffer dtypes must match for all_gather. Got {send_buffer.dtype} vs {recv_buffer.dtype}."
+    mesh_shape = get_target_mesh_shape("auto")
+
+    recv_num = 1
+    if direction.lower() in ["horizontal", "h"]:
+        recv_num = mesh_shape["y"]
+    elif direction.lower() in ["vertical", "v"]:
+        recv_num = mesh_shape["x"]
+    elif direction.lower() in ["all", "a"]:
+        recv_num = mesh_shape["x"] * mesh_shape["y"]
+
+    expected_recv_shape = [recv_num] + list(send_buffer.shape)
+    assert (
+        list(recv_buffer.shape) == expected_recv_shape
+    ), f"Receive buffer shape must be {expected_recv_shape} to hold gathered data from {recv_num} cores, but got {recv_buffer.shape}."
+
+    assert isinstance(size, int) and size >= -1, "size must be an integer >= -1."
+    send_elements = 1
+    for dim in send_buffer.shape:
+        send_elements *= dim
+    assert (
+        size <= send_elements
+    ), f"size {size} exceeds send buffer size {send_elements}."
+
+    send_buffer_region = to_buffer_region(send_buffer)
+    recv_buffer_region = to_buffer_region(recv_buffer)
+
+    direction_map = {"horizontal": 0, "h": 0, "vertical": 1, "v": 1, "all": 2, "a": 2}
+    args = (
+        send_buffer_region,
+        recv_buffer_region,
+        direction_map[direction.lower()],
+        size,
+    )
+    return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.comm_allgather"), *args)
+
+
+def barrier(group: Iterable[tuple[int, int]] | None = None):
+    """Insert a synchronization barrier among a group of cores.
+
+    Parameters
+    ----------
+    group : iterable of tuple[int, int] | None
+        Optional set of core coordinates to synchronize. If omitted, the
+        runtime's default participant set is used.
+
+    Returns
+    -------
+    tir.Call
+        The TIR intrinsic call handle for `tl.comm_barrier`.
+ + Examples + -------- + >>> barrier() + >>> barrier(group=[(0,0),(0,1)]) + """ + if group is None: + return tir.call_intrin("handle", tir.op.Op.get("tl.comm_barrier")) + else: + group = [core_tuple_to_id(core_id) for core_id in group] + return tir.call_intrin("handle", tir.op.Op.get("tl.comm_barrier"), *group) + + +def fence(): + """Emit a memory/communication fence intrinsic. + + Returns + ------- + tir.Call + The TIR intrinsic call handle for `tl.comm_fence`. + + Examples + -------- + >>> fence() + """ + return tir.call_intrin("handle", tir.op.Op.get("tl.comm_fence")) From 6eb4f6233d5ce91ff4347b1fc26addeb0392ac96 Mon Sep 17 00:00:00 2001 From: xiaoyao-NKU <18946090101@163.com> Date: Thu, 22 Jan 2026 13:08:31 +0800 Subject: [PATCH 2/5] add Unit Test for comm_broadcast, comm_put, comm_all_gather --- .../language/test_tilelang_language_comm.py | 237 ++++++++++++++++++ 1 file changed, 237 insertions(+) create mode 100644 testing/python/language/test_tilelang_language_comm.py diff --git a/testing/python/language/test_tilelang_language_comm.py b/testing/python/language/test_tilelang_language_comm.py new file mode 100644 index 000000000..e6772bbb7 --- /dev/null +++ b/testing/python/language/test_tilelang_language_comm.py @@ -0,0 +1,237 @@ +import pytest + +import tilelang +import tilelang.language as T + +from tilelang import tvm as tvm +from tilelang.utils.target import determine_target + + +@pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ + (1024, 1024, 128, 128, "float16", "float"), +]) +def test_comm_python_api(M, N, block_M, block_N, dtype="float16", accum_dtype="float"): + func_str = """# from tvm.script import tir as T + +@T.prim_func +def main(A_handle: T.handle): + A = T.match_buffer(A_handle, (1024, 1024), "float16", strides=(1024, 1)) + # with T.block("root"): + bx = T.launch_thread("blockIdx.x", 8) + by = T.launch_thread("blockIdx.y", 8) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = 
T.launch_thread("threadIdx.z", 1) + with T.block("tilelang_root"): + T.reads(A[by * 128, bx * 128]) + T.writes() + A_local = T.alloc_buffer((128, 128), scope="local.fragment") + B_local = T.alloc_buffer((128, 128), scope="local.fragment") + C_local = T.alloc_buffer((16, 128, 128), scope="local.fragment") + T.copy(T.region(A[by * 128, bx * 128], 1, 128, 128), T.region(A_local[0, 0], 2, 128, 128), -1, T.bool(False), 0) + T.comm_broadcast(A_local[0:128, 0:128], B_local[0:128, 0:128], -1, 0, 6, 2) + T.comm_put(A_local[0:128, 0:128], B_local[0:128, 0:128], -1, 6, 11) + T.comm_allgather(A_local[0:128, 0:128], C_local[0:16, 0:128, 0:128], 2, -1)""" + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_local = T.alloc_fragment([block_M, block_N], accum_dtype) + B_local = T.alloc_fragment([block_M, block_N], accum_dtype) + C_local = T.alloc_fragment([16, block_M, block_N], accum_dtype) + T.copy(A[by * block_M, bx * block_N], A_local) + + T.comm.broadcast(A_local, B_local, (1, 2), direction="all") + T.comm.put(A_local, B_local, (1, 2), (2, 3)) + T.comm.all_gather(A_local, C_local, direction="all") + + assert main.script() == func_str, "The generated script does not match the expected output." 
+ + +@pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ + (1024, 1024, 128, 128, "float16", "float"), +]) +def test_comm_broadcast_lower(M, N, block_M, block_N, dtype="float16", accum_dtype="float"): + func_str = """# from tvm.script import ir as I +# from tvm.script import tir as T + +@I.ir_module +class Module: + @T.prim_func + def main(A_handle: T.handle): + T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mattr": ["device_mesh_nrow_4", "device_mesh_ncol_4"], "mcpu": "sunmmio-a4e", "tag": ""})}) + A = T.match_buffer(A_handle, (1024, 1024), "float16", strides=(1024, 1)) + # with T.block("root"): + bx = T.launch_thread("blockIdx.x", 8) + by = T.launch_thread("blockIdx.y", 8) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + with T.block("tilelang_root"): + T.reads(A[by * 128, bx * 128]) + T.writes() + A_local = T.alloc_buffer((128, 128), scope="local.fragment") + B_local = T.alloc_buffer((128, 128), scope="local.fragment") + for i in T.parallel(128): + for j in T.parallel(32): + for vec in T.vectorized(4): + A_local[i, j * 4 + vec] = T.Cast("float32", A[by * 128 + i, bx * 128 + (j * 4 + vec)]) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 6, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 2, 0, 2) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 6, 0, 2) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 10, 0, 2) + 
T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 14, 0, 2)""" + + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_local = T.alloc_fragment([block_M, block_N], accum_dtype) + B_local = T.alloc_fragment([block_M, block_N], accum_dtype) + T.copy(A[by * block_M, bx * block_N], A_local) + + T.comm.broadcast(A_local, B_local, (1, 2), direction="all") + + mod = tvm.IRModule({'main': main}) + target = determine_target("Sunmmio", return_object=True) + with tvm.target.Target(target): + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tilelang.transform.LowerTileOp()(mod) + assert mod.script() == func_str, "The generated script does not match the expected output." + + +@pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ + (1024, 1024, 128, 128, "float16", "float"), +]) +def test_comm_put_lower(M, N, block_M, block_N, dtype="float16", accum_dtype="float"): + func_str = """# from tvm.script import ir as I +# from tvm.script import tir as T + +@I.ir_module +class Module: + @T.prim_func + def main(A_handle: T.handle): + T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mattr": ["device_mesh_nrow_4", "device_mesh_ncol_4"], "mcpu": "sunmmio-a4e", "tag": ""})}) + A = T.match_buffer(A_handle, (1024, 1024), "float16", strides=(1024, 1)) + # with T.block("root"): + bx = T.launch_thread("blockIdx.x", 8) + by = T.launch_thread("blockIdx.y", 8) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + with T.block("tilelang_root"): + T.reads(A[by * 128, bx * 128]) + T.writes() + A_local = T.alloc_buffer((128, 128), scope="local.fragment") + B_local = T.alloc_buffer((128, 128), scope="local.fragment") + for i in T.parallel(128): + for j in T.parallel(32): + 
for vec in T.vectorized(4): + A_local[i, j * 4 + vec] = T.Cast("float32", A[by * 128 + i, bx * 128 + (j * 4 + vec)]) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 6, 1, 0, 1, 3) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 7, 0, 0, 1, 2)""" + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_local = T.alloc_fragment([block_M, block_N], accum_dtype) + B_local = T.alloc_fragment([block_M, block_N], accum_dtype) + T.copy(A[by * block_M, bx * block_N], A_local) + + T.comm.put(A_local, B_local, (1, 2), (2, 3)) + + mod = tvm.IRModule({'main': main}) + target = determine_target("Sunmmio", return_object=True) + with tvm.target.Target(target): + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tilelang.transform.LowerTileOp()(mod) + assert mod.script() == func_str, "The generated script does not match the expected output." 
+ + +@pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ + (1024, 1024, 128, 128, "float16", "float"), +]) +def test_comm_all_gather_lower(M, N, block_M, block_N, dtype="float16", accum_dtype="float"): + func_str = """# from tvm.script import ir as I +# from tvm.script import tir as T + +@I.ir_module +class Module: + @T.prim_func + def main(A_handle: T.handle): + T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mattr": ["device_mesh_nrow_4", "device_mesh_ncol_4"], "mcpu": "sunmmio-a4e", "tag": ""})}) + A = T.match_buffer(A_handle, (1024, 1024), "float16", strides=(1024, 1)) + # with T.block("root"): + bx = T.launch_thread("blockIdx.x", 8) + by = T.launch_thread("blockIdx.y", 8) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + with T.block("tilelang_root"): + T.reads(A[by * 128, bx * 128]) + T.writes() + A_local = T.alloc_buffer((128, 128), scope="local.fragment") + C_local = T.alloc_buffer((16, 128, 128), scope="local.fragment") + for i in T.parallel(128): + for j in T.parallel(32): + for vec in T.vectorized(4): + A_local[i, j * 4 + vec] = T.Cast("float32", A[by * 128 + i, bx * 128 + (j * 4 + vec)]) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 16384, 2), 16384, 0, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 16384, 16384, 2), 16384, 1, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 32768, 16384, 2), 16384, 2, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 49152, 16384, 2), 16384, 3, 0) + 
T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 16384, 2), 16384, 4, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 81920, 16384, 2), 16384, 5, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 98304, 16384, 2), 16384, 6, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 114688, 16384, 2), 16384, 7, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 16384, 2), 16384, 8, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 147456, 16384, 2), 16384, 9, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 163840, 16384, 2), 16384, 10, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 180224, 16384, 2), 16384, 11, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 16384, 2), 16384, 12, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 212992, 16384, 2), 16384, 13, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 229376, 16384, 2), 
16384, 14, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 245760, 16384, 2), 16384, 15, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 0, 1, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 4, 1, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 8, 1, 2) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 12, 1, 3) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 1, 1, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 5, 1, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 9, 1, 2) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 13, 1, 3) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 2, 1, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), 
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 6, 1, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 10, 1, 2) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 14, 1, 3) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 3, 1, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 7, 1, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 11, 1, 2) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 15, 1, 3)""" + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_local = T.alloc_fragment([block_M, block_N], accum_dtype) + C_local = T.alloc_fragment([16, block_M, block_N], accum_dtype) + T.copy(A[by * block_M, bx * block_N], A_local) + + T.comm.all_gather(A_local, C_local, direction="all") + mod = tvm.IRModule({'main': main}) + target = determine_target("Sunmmio", return_object=True) + with tvm.target.Target(target): + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tilelang.transform.LowerTileOp()(mod) + assert mod.script() == func_str, "The generated script does not match the expected output." 
+ + +if __name__ == "__main__": + test_comm_python_api(1024, 1024, 128, 128, "float16", "float") + test_comm_broadcast_lower(1024, 1024, 128, 128, "float16", "float") + test_comm_put_lower(1024, 1024, 128, 128, "float16", "float") + test_comm_all_gather_lower(1024, 1024, 128, 128, "float16", "float") From 14e7ca9b247166526045da23f91e9c9ffffbdd85 Mon Sep 17 00:00:00 2001 From: xiaoyao-NKU <18946090101@163.com> Date: Thu, 22 Jan 2026 13:31:01 +0800 Subject: [PATCH 3/5] fix for pre-commot --- src/op/comm.cc | 12 +-- .../language/test_tilelang_language_comm.py | 51 +++++------ tilelang/language/__init__.py | 1 + tilelang/language/comm.py | 91 +++++++------------ 4 files changed, 64 insertions(+), 91 deletions(-) diff --git a/src/op/comm.cc b/src/op/comm.cc index 19b1edbc0..2248bacc9 100644 --- a/src/op/comm.cc +++ b/src/op/comm.cc @@ -85,7 +85,7 @@ LayoutMap BroadcastNode::InferLayout(const LayoutInferArgs &T, args.push_back(dst_expr); Copy copy_op = Copy(args); LayoutMap out_layout = copy_op->InferLayout(T, level); - return out_layout; + return out_layout; } // int get_target_mesh_nrows(Target target) { @@ -123,12 +123,13 @@ int get_target_mesh(Target target, int axis) { for (size_t i = 0; i < mattr.size(); i++) { std::string m = mattr[i]; if (m.find(axis_str) != std::string::npos) { - std::string s = m.substr(m.find_last_of('_') + 1);; + std::string s = m.substr(m.find_last_of('_') + 1); + ; try { x = std::stoi(s); - } catch (const std::invalid_argument& e) { + } catch (const std::invalid_argument &e) { x = -1; - } catch (const std::out_of_range& e) { + } catch (const std::out_of_range &e) { x = -1; } } @@ -138,7 +139,6 @@ int get_target_mesh(Target target, int axis) { return x; } - Stmt BroadcastNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Target target = T.target; ICHECK(TargetIsSunmmio(target)) << "Broadcast only supports SUNMMIO targets."; @@ -444,7 +444,7 @@ Layout AllgatherNode::ComputeLinearLayout(const Buffer &shared_tensor) const { 
LayoutMap AllgatherNode::InferLayout(const LayoutInferArgs &T, InferLevel level) const { - Buffer recv_buffer = NormalizeToBufferRegion(recv)->buffer; + Buffer recv_buffer = NormalizeToBufferRegion(recv)->buffer; Layout linear_layout = ComputeLinearLayout(recv_buffer); return Map({{recv_buffer, linear_layout}}); } diff --git a/testing/python/language/test_tilelang_language_comm.py b/testing/python/language/test_tilelang_language_comm.py index e6772bbb7..4817c90a9 100644 --- a/testing/python/language/test_tilelang_language_comm.py +++ b/testing/python/language/test_tilelang_language_comm.py @@ -34,13 +34,11 @@ def main(A_handle: T.handle): T.comm_allgather(A_local[0:128, 0:128], C_local[0:16, 0:128, 0:128], 2, -1)""" @T.prim_func - def main( - A: T.Tensor((M, N), dtype), - ): + def main(A: T.Tensor((M, N), dtype),): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_local = T.alloc_fragment([block_M, block_N], accum_dtype) - B_local = T.alloc_fragment([block_M, block_N], accum_dtype) - C_local = T.alloc_fragment([16, block_M, block_N], accum_dtype) + A_local = T.alloc_fragment([block_M, block_N], accum_dtype) + B_local = T.alloc_fragment([block_M, block_N], accum_dtype) + C_local = T.alloc_fragment([16, block_M, block_N], accum_dtype) T.copy(A[by * block_M, bx * block_N], A_local) T.comm.broadcast(A_local, B_local, (1, 2), direction="all") @@ -48,7 +46,7 @@ def main( T.comm.all_gather(A_local, C_local, direction="all") assert main.script() == func_str, "The generated script does not match the expected output." 
- + @pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ (1024, 1024, 128, 128, "float16", "float"), @@ -85,23 +83,21 @@ def main(A_handle: T.handle): T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 14, 0, 2)""" @T.prim_func - def main( - A: T.Tensor((M, N), dtype), - ): + def main(A: T.Tensor((M, N), dtype),): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_local = T.alloc_fragment([block_M, block_N], accum_dtype) - B_local = T.alloc_fragment([block_M, block_N], accum_dtype) + A_local = T.alloc_fragment([block_M, block_N], accum_dtype) + B_local = T.alloc_fragment([block_M, block_N], accum_dtype) T.copy(A[by * block_M, bx * block_N], A_local) T.comm.broadcast(A_local, B_local, (1, 2), direction="all") - + mod = tvm.IRModule({'main': main}) target = determine_target("Sunmmio", return_object=True) with tvm.target.Target(target): mod = tvm.tir.transform.BindTarget(target)(mod) mod = tilelang.transform.LowerTileOp()(mod) assert mod.script() == func_str, "The generated script does not match the expected output." 
- + @pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ (1024, 1024, 128, 128, "float16", "float"), @@ -133,25 +129,24 @@ def main(A_handle: T.handle): A_local[i, j * 4 + vec] = T.Cast("float32", A[by * 128 + i, bx * 128 + (j * 4 + vec)]) T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 6, 1, 0, 1, 3) T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 7, 0, 0, 1, 2)""" + @T.prim_func - def main( - A: T.Tensor((M, N), dtype), - ): + def main(A: T.Tensor((M, N), dtype),): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_local = T.alloc_fragment([block_M, block_N], accum_dtype) - B_local = T.alloc_fragment([block_M, block_N], accum_dtype) + A_local = T.alloc_fragment([block_M, block_N], accum_dtype) + B_local = T.alloc_fragment([block_M, block_N], accum_dtype) T.copy(A[by * block_M, bx * block_N], A_local) T.comm.put(A_local, B_local, (1, 2), (2, 3)) - + mod = tvm.IRModule({'main': main}) target = determine_target("Sunmmio", return_object=True) with tvm.target.Target(target): mod = tvm.tir.transform.BindTarget(target)(mod) mod = tilelang.transform.LowerTileOp()(mod) assert mod.script() == func_str, "The generated script does not match the expected output." 
- - + + @pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ (1024, 1024, 128, 128, "float16", "float"), ]) @@ -212,16 +207,16 @@ def main(A_handle: T.handle): T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 7, 1, 1) T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 11, 1, 2) T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 15, 1, 3)""" + @T.prim_func - def main( - A: T.Tensor((M, N), dtype), - ): + def main(A: T.Tensor((M, N), dtype),): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_local = T.alloc_fragment([block_M, block_N], accum_dtype) - C_local = T.alloc_fragment([16, block_M, block_N], accum_dtype) + A_local = T.alloc_fragment([block_M, block_N], accum_dtype) + C_local = T.alloc_fragment([16, block_M, block_N], accum_dtype) T.copy(A[by * block_M, bx * block_N], A_local) - T.comm.all_gather(A_local, C_local, direction="all") + T.comm.all_gather(A_local, C_local, direction="all") + mod = tvm.IRModule({'main': main}) target = determine_target("Sunmmio", return_object=True) with tvm.target.Target(target): diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index ea77c0d5c..3802ae29b 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -108,6 +108,7 @@ ) from . 
import comm # noqa: F401 + def import_source(source: str | None = None): # source is the source code to be imported return block_attr({"pragma_import_c": source}) if source is not None else None diff --git a/tilelang/language/comm.py b/tilelang/language/comm.py index c42f26c2a..72d9329cf 100644 --- a/tilelang/language/comm.py +++ b/tilelang/language/comm.py @@ -13,7 +13,6 @@ import tilelang.language as T from tilelang.utils.language import to_buffer_region - DIRECTION_MAP = {"horizontal": 0, "h": 0, "vertical": 1, "v": 1, "all": 2, "a": 2} @@ -38,7 +37,7 @@ def get_target_mesh_shape(target: str = "auto") -> dict[str, int]: return {"x": 2, "y": 4} else: raise ValueError(f"Unknown target: {target}") - + def core_tuple_to_id(core_id: tuple[int, int]) -> int: """Convert 2D (row, col) coordinates on the mesh into a linear core id. @@ -60,12 +59,8 @@ def core_tuple_to_id(core_id: tuple[int, int]) -> int: """ mesh_shape = get_target_mesh_shape("auto") row, col = core_id - assert ( - 0 <= row < mesh_shape["x"] - ), f"Row {row} out of bounds for mesh shape {mesh_shape}." - assert ( - 0 <= col < mesh_shape["y"] - ), f"Col {col} out of bounds for mesh shape {mesh_shape}." + assert (0 <= row < mesh_shape["x"]), f"Row {row} out of bounds for mesh shape {mesh_shape}." + assert (0 <= col < mesh_shape["y"]), f"Col {col} out of bounds for mesh shape {mesh_shape}." 
core_id_value = row * mesh_shape["y"] + col return core_id_value @@ -117,19 +112,14 @@ def CoreId(core_id: int | tuple[int, int]): mesh_shape = get_target_mesh_shape("auto") if isinstance(core_id, tuple): row, col = core_id - assert ( - 0 <= row < mesh_shape["x"] - ), f"Row {row} out of bounds for mesh shape {mesh_shape}" - assert ( - 0 <= col < mesh_shape["y"] - ), f"Col {col} out of bounds for mesh shape {mesh_shape}" + assert (0 <= row < mesh_shape["x"]), f"Row {row} out of bounds for mesh shape {mesh_shape}" + assert (0 <= col < mesh_shape["y"]), f"Col {col} out of bounds for mesh shape {mesh_shape}" # Convert 2D coordinates into a linear core id. core_id_value = row * mesh_shape["x"] + col elif isinstance(core_id, int): core_id_value = core_id - assert ( - 0 <= core_id_value < mesh_shape["x"] * mesh_shape["y"] - ), f"Core ID {core_id_value} out of bounds for mesh shape {mesh_shape}" + assert (0 <= core_id_value < mesh_shape["x"] * mesh_shape["y"] + ), f"Core ID {core_id_value} out of bounds for mesh shape {mesh_shape}" else: raise ValueError("core_id must be either a tuple[int, int] or an int.") return tir.call_intrin("handle", tir.op.Op.get("tl.CoreId"), core_id_value) @@ -185,22 +175,19 @@ def broadcast( ), f"Source and destination buffer dtypes must match for broadcast. Got {src.dtype} vs {dst.dtype}." if len(src.shape) != len(dst.shape): raise ValueError( - "Source and destination buffer must have the same number of dimensions for broadcast." - ) + "Source and destination buffer must have the same number of dimensions for broadcast.") for i in range(len(src.shape)): - assert (src.shape[i] == dst.shape[i] or src.shape[i] == 1 or dst.shape[i] == 1 + assert ( + src.shape[i] == dst.shape[i] or src.shape[i] == 1 or dst.shape[i] == 1 ), f"Source buffer shape and destination buffer shape must match for broadcast. Got {src.shape} vs {dst.shape}." 
mesh_shape = get_target_mesh_shape("auto") - assert ( - isinstance(src_core, tuple) and len(src_core) == 2 - ), "src_core must be a tuple of (row, col)." - assert ( - 0 <= src_core[0] < mesh_shape["x"] - ), f"src_core row {src_core[0]} out of bounds for mesh shape {mesh_shape}." - assert ( - 0 <= src_core[1] < mesh_shape["y"] - ), f"src_core col {src_core[1]} out of bounds for mesh shape {mesh_shape}." + assert (isinstance(src_core, tuple) and + len(src_core) == 2), "src_core must be a tuple of (row, col)." + assert (0 <= src_core[0] < mesh_shape["x"] + ), f"src_core row {src_core[0]} out of bounds for mesh shape {mesh_shape}." + assert (0 <= src_core[1] < mesh_shape["y"] + ), f"src_core col {src_core[1]} out of bounds for mesh shape {mesh_shape}." src_elements = 1 for dim in src.shape: @@ -260,39 +247,30 @@ def put( ), f"Source and destination buffer dtypes must match for broadcast. Got {src.dtype} vs {dst.dtype}." if len(src.shape) != len(dst.shape): raise ValueError( - "Source and destination buffer must have the same number of dimensions for broadcast." - ) + "Source and destination buffer must have the same number of dimensions for broadcast.") for i in range(len(src.shape)): assert ( src.shape[i] == dst.shape[i] or src.shape[i] == 1 or dst.shape[i] == 1 ), f"Source buffer shape and destination buffer shape must match for broadcast. Got {src.shape} vs {dst.shape}." mesh_shape = get_target_mesh_shape("auto") - assert ( - isinstance(src_core, tuple) and len(src_core) == 2 - ), "src_core must be a tuple of (row, col)." - assert ( - 0 <= src_core[0] < mesh_shape["x"] - ), f"src_core row {src_core[0]} out of bounds for mesh shape {mesh_shape}." - assert ( - 0 <= src_core[1] < mesh_shape["y"] - ), f"src_core col {src_core[1]} out of bounds for mesh shape {mesh_shape}." - assert ( - isinstance(dst_core, tuple) and len(dst_core) == 2 - ), "dst_core must be a tuple of (row, col)." 
- assert ( - 0 <= dst_core[0] < mesh_shape["x"] - ), f"dst_core row {dst_core[0]} out of bounds for mesh shape {mesh_shape}." - assert ( - 0 <= dst_core[1] < mesh_shape["y"] - ), f"dst_core col {dst_core[1]} out of bounds for mesh shape {mesh_shape}." + assert (isinstance(src_core, tuple) and + len(src_core) == 2), "src_core must be a tuple of (row, col)." + assert (0 <= src_core[0] < mesh_shape["x"] + ), f"src_core row {src_core[0]} out of bounds for mesh shape {mesh_shape}." + assert (0 <= src_core[1] < mesh_shape["y"] + ), f"src_core col {src_core[1]} out of bounds for mesh shape {mesh_shape}." + assert (isinstance(dst_core, tuple) and + len(dst_core) == 2), "dst_core must be a tuple of (row, col)." + assert (0 <= dst_core[0] < mesh_shape["x"] + ), f"dst_core row {dst_core[0]} out of bounds for mesh shape {mesh_shape}." + assert (0 <= dst_core[1] < mesh_shape["y"] + ), f"dst_core col {dst_core[1]} out of bounds for mesh shape {mesh_shape}." src_elements = 1 for dim in src.shape: src_elements *= dim assert isinstance(size, int) and size >= -1, "size must be an integer >= -1." - assert ( - size <= src_elements - ), f"size {size} exceeds source buffer size {src_elements}." + assert (size <= src_elements), f"size {size} exceeds source buffer size {src_elements}." src_region = to_buffer_region(src) dst_region = to_buffer_region(dst) @@ -330,8 +308,9 @@ def all_gather( >>> all_gather(A_local, C_local, direction="horizontal") """ assert direction.lower() in DIRECTION_MAP, f"Invalid direction string: {direction}" - - assert (send_buffer.dtype == recv_buffer.dtype + + assert ( + send_buffer.dtype == recv_buffer.dtype ), f"Source and destination buffer dtypes must match for broadcast. Got {send_buffer.dtype} vs {recv_buffer.dtype}." 
mesh_shape = get_target_mesh_shape("auto") @@ -352,9 +331,7 @@ def all_gather( send_elements = 1 for dim in send_buffer.shape: send_elements *= dim - assert ( - size <= send_elements - ), f"size {size} exceeds send buffer size {send_elements}." + assert (size <= send_elements), f"size {size} exceeds send buffer size {send_elements}." send_buffer_region = to_buffer_region(send_buffer) recv_buffer_region = to_buffer_region(recv_buffer) From 82e9700d2c014b165db5d99b91ba34d787ad968f Mon Sep 17 00:00:00 2001 From: xiaoyao-NKU <18946090101@163.com> Date: Thu, 29 Jan 2026 10:41:39 +0800 Subject: [PATCH 4/5] Add all_reduce TileOp and some changes addressing the suggestions on GitHub. --- src/op/comm.cc | 503 +++++++++++++++--- src/op/comm.h | 136 +++-- .../language/test_tilelang_language_comm.py | 147 ++++- tilelang/language/comm.py | 173 ++++-- 4 files changed, 769 insertions(+), 190 deletions(-) diff --git a/src/op/comm.cc b/src/op/comm.cc index 2248bacc9..135c44c0d 100644 --- a/src/op/comm.cc +++ b/src/op/comm.cc @@ -52,8 +52,8 @@ TIR_DEFINE_TL_BUILTIN(broadcast_) using namespace tir; -Broadcast::Broadcast(Array args) { - ObjectPtr node = tvm::ffi::make_object(); +BroadcastOp::BroadcastOp(Array args) { + ObjectPtr node = tvm::ffi::make_object(); node->src_expr = args[0]; node->dst_expr = args[1]; Array rgs[2]; @@ -73,13 +73,13 @@ Broadcast::Broadcast(Array args) { data_ = std::move(node); } -TileOperator BroadcastNode::Clone() const { - auto op = tvm::ffi::make_object(*this); - return Broadcast(op); +TileOperator BroadcastOpNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return BroadcastOp(op); } -LayoutMap BroadcastNode::InferLayout(const LayoutInferArgs &T, - InferLevel level) const { +LayoutMap BroadcastOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { Array args; args.push_back(src_expr); args.push_back(dst_expr); @@ -88,27 +88,6 @@ LayoutMap BroadcastNode::InferLayout(const LayoutInferArgs &T, return out_layout; } -// int 
get_target_mesh_nrows(Target target) { -// auto mattr = target->GetAttr>("mattr").value(); -// int x = 0; -// for (size_t i = 0; i < mattr.size(); i++) { -// std::string m = mattr[i]; -// if (m.find("device_mesh_nrow_") != std::string::npos) { -// std::string s = m.substr(m.find_last_of('_') + 1);; -// try { -// x = std::stoi(s); -// } catch (const std::invalid_argument& e) { -// x = -1; -// } catch (const std::out_of_range& e) { -// x = -1; -// } -// } -// } -// ICHECK(x != 0) << "Device mesh row number not found."; -// ICHECK(x > 0) << "Invalid device mesh row number: "; -// return x; -// } - int get_target_mesh(Target target, int axis) { auto mattr = target->GetAttr>("mattr").value(); int x = 0; @@ -139,7 +118,8 @@ int get_target_mesh(Target target, int axis) { return x; } -Stmt BroadcastNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { +Stmt BroadcastOpNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { Target target = T.target; ICHECK(TargetIsSunmmio(target)) << "Broadcast only supports SUNMMIO targets."; int mesh_x = get_target_mesh(target, 0); @@ -222,16 +202,14 @@ Stmt BroadcastNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { args.push_back(1); // direction: vertical Stmt broadcast = Evaluate(Call(DataType::Handle(), broadcast_(), args)); seq.push_back(broadcast); - // herizontal broadcast + // horizontal broadcast for (int i = 0; i < mesh_x; i++) { Array args; args.push_back(dst.access_ptr(1, DataType::Handle(), 1, 0, dst_elements)); args.push_back(dst.access_ptr(2, DataType::Handle(), 1, 0, dst_elements)); args.push_back(Downcast(broadcast_elements)); args.push_back(int(i * mesh_y) + src_core_y); - args.push_back(0); // direction: horeizontal - args.push_back( - IntImm(DataType::Int(32), src_core_y)); // mask: current core only + args.push_back(0); // direction: horizontal Stmt broadcast = Evaluate(Call(DataType::Handle(), broadcast_(), args)); seq.push_back(broadcast); } @@ -239,13 +217,13 @@ Stmt 
BroadcastNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } } -TIR_REGISTER_TL_TILE_OP(Broadcast, comm_broadcast) +TIR_REGISTER_TL_TILE_OP(BroadcastOp, comm_broadcast) .set_num_inputs(6) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -Put::Put(Array args) { - ObjectPtr node = tvm::ffi::make_object(); +PutOp::PutOp(Array args) { + ObjectPtr node = tvm::ffi::make_object(); node->src_expr = args[0]; node->dst_expr = args[1]; Array rgs[2]; @@ -263,13 +241,13 @@ Put::Put(Array args) { data_ = std::move(node); } -TileOperator PutNode::Clone() const { - auto op = tvm::ffi::make_object(*this); - return Put(op); +TileOperator PutOpNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return PutOp(op); } -LayoutMap PutNode::InferLayout(const LayoutInferArgs &T, - InferLevel level) const { +LayoutMap PutOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { Array args; args.push_back(src_expr); args.push_back(dst_expr); @@ -278,7 +256,7 @@ LayoutMap PutNode::InferLayout(const LayoutInferArgs &T, return out_layout; } -Stmt PutNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { +Stmt PutOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Target target = T.target; ICHECK(TargetIsSunmmio(target)) << "Put only supports SUNMMIO targets."; int mesh_x = get_target_mesh(target, 0); @@ -308,7 +286,7 @@ Stmt PutNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { << "Source buffer size larger than destination buffer size: " << src_elements << " vs " << dst_elements; ICHECK(size->value <= Downcast(src_elements)->value) - << "Broadcast size larger than data size: " << size->value << " vs " + << "Put size larger than data size: " << size->value << " vs " << Downcast(src_elements)->value; // check for size @@ -320,12 +298,12 @@ Stmt PutNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } ICHECK((Downcast(broadcast_elements)->value) <= Downcast(src_elements)->value) 
- << "Broadcast size Larger than source buffer size: " + << "Put size Larger than source buffer size: " << (Downcast(broadcast_elements)->value) << " vs " << Downcast(src_elements)->value; ICHECK((Downcast(broadcast_elements)->value) <= Downcast(dst_elements)->value) - << "Broadcast size larger than destination buffer size: " + << "Put size larger than destination buffer size: " << (Downcast(broadcast_elements)->value) << " vs " << Downcast(dst_elements)->value; @@ -406,13 +384,13 @@ Stmt PutNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } } -TIR_REGISTER_TL_TILE_OP(Put, comm_put) +TIR_REGISTER_TL_TILE_OP(PutOp, comm_put) .set_num_inputs(5) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -Allgather::Allgather(Array args) { - ObjectPtr node = tvm::ffi::make_object(); +AllgatherOp::AllgatherOp(Array args) { + ObjectPtr node = tvm::ffi::make_object(); node->send = args[0]; node->recv = args[1]; node->direction = Downcast(args[2])->value; @@ -420,36 +398,86 @@ Allgather::Allgather(Array args) { data_ = std::move(node); } -TileOperator AllgatherNode::Clone() const { - auto op = tvm::ffi::make_object(*this); - return Allgather(op); +TileOperator AllgatherOpNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return AllgatherOp(op); } -Layout AllgatherNode::ComputeLinearLayout(const Buffer &shared_tensor) const { - Array input_size = shared_tensor->shape; - Array forward_vars; - for (size_t i = 0; i < input_size.size(); i++) { - forward_vars.push_back(InputPlaceholder(i)); - } - // [i, j] -> [i // 256, j // 256, i % 256, j % 256] - Array forward_index; - for (size_t i = 0; i < input_size.size(); i++) { - forward_index.push_back(FloorDiv(forward_vars[i], 256)); - } - for (size_t i = 0; i < input_size.size(); i++) { - forward_index.push_back(FloorMod(forward_vars[i], 256)); +// Not yet complete; it will be further refined later +LayoutMap AllgatherOpNode::ComputeLayout(const LayoutInferArgs &T, + InferLevel level, Buffer src, 
+ Buffer dst) const { + if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" && + T.layout_map.count(src)) { + auto src_layout = T.layout_map[src].as().value(); + + PrimExpr src_rep_extent = src_layout->ReplicateExtent(); + + Array fwd; + fwd.push_back(InputPlaceholder(0)); + for (int i = 0; i < static_cast(src->shape.size()); i++) { + fwd.push_back(InputPlaceholder(i + 1)); + } + auto thd = src_layout->ForwardThread(fwd, std::nullopt); + + Fragment dst_layout = + Fragment(dst->shape, {}, thd, src_rep_extent, std::nullopt) + ->CondenseReplicateVar() + ->BindThreadRange(T.thread_bounds); + + if (!T.layout_map.count(dst)) + return {{dst, dst_layout}}; + else { + // Check if computed layout is compatible with existing: the existing one + // must strictly contains the computed layout + auto orig_dst_layout = + T.layout_map.Get(dst).value().as().value(); + ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim()); + Array indices; + indices.reserve(dst_layout->InputDim()); + arith::Analyzer inner_analyzer; + for (int i = 0; i < dst_layout->InputDim(); ++i) { + auto x = InputPlaceholder(i); + indices.push_back(x); + // should be literal - literal = 0, any analyzer will work + ICHECK(is_zero(inner_analyzer.Simplify( + dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i]))); + inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i])); + } + + ICHECK(as_const_int(dst_layout->ReplicateExtent())); + ICHECK(as_const_int(src_layout->ReplicateExtent())); + auto dst_rep = *as_const_int(dst_layout->ReplicateExtent()); + auto src_rep = *as_const_int(src_layout->ReplicateExtent()); + if (dst_rep < src_rep || + !ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices, + inner_analyzer)) { + std::ostringstream oss; + oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. 
" + << src << "\nLHS = " << src_layout->DebugOutput() + << "\nRHS = " << orig_dst_layout->DebugOutput() + << "\nYou may need to use a shared memory to transform the " + "layout"; + throw LayoutConflictException(oss.str()); + } + + if (dst_rep > src_rep) { + return {{dst, dst_layout}}; + } + } } - return Layout(input_size, forward_index); + return {}; } -LayoutMap AllgatherNode::InferLayout(const LayoutInferArgs &T, - InferLevel level) const { +LayoutMap AllgatherOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + Buffer src_buffer = NormalizeToBufferRegion(send)->buffer; Buffer recv_buffer = NormalizeToBufferRegion(recv)->buffer; - Layout linear_layout = ComputeLinearLayout(recv_buffer); - return Map({{recv_buffer, linear_layout}}); + return ComputeLayout(T, level, src_buffer, recv_buffer); } -Stmt AllgatherNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { +Stmt AllgatherOpNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { Target target = T.target; ICHECK(TargetIsSunmmio(target)) << "Allgather only supports SUNMMIO targets."; int mesh_x = get_target_mesh(target, 0); @@ -503,7 +531,7 @@ Stmt AllgatherNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { args.push_back(IntImm(DataType::Int(32), j) * send_elements); // offset args.push_back(IntImm(DataType::Int(32), i * mesh_y + j)); // src_core args.push_back(0); // direction: horizontal - Broadcast bcast = Broadcast(args); + BroadcastOp bcast = BroadcastOp(args); Stmt bcast_stmt = bcast->Lower(T, analyzer); bcast_stmts.push_back(bcast_stmt); } @@ -518,7 +546,7 @@ Stmt AllgatherNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { args.push_back(IntImm(DataType::Int(32), i) * send_elements); // offset args.push_back(IntImm(DataType::Int(32), i * mesh_y + j)); // src_core args.push_back(1); // direction: vertical - Broadcast bcast = Broadcast(args); + BroadcastOp bcast = BroadcastOp(args); Stmt bcast_stmt = bcast->Lower(T, analyzer); 
bcast_stmts.push_back(bcast_stmt); } @@ -535,7 +563,7 @@ Stmt AllgatherNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { send_elements); // offset args.push_back(IntImm(DataType::Int(32), i * mesh_y + j)); // src_core args.push_back(0); // direction: horizontal - Broadcast bcast = Broadcast(args); + BroadcastOp bcast = BroadcastOp(args); Stmt bcast_stmt = bcast->Lower(T, analyzer); bcast_stmts.push_back(bcast_stmt); } @@ -559,8 +587,7 @@ Stmt AllgatherNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { IntImm(DataType::Int(32), mesh_y) * send_elements)); args.push_back(IntImm(DataType::Int(32), allgather_size)); // size args.push_back(IntImm(DataType::Int(32), i * mesh_y + j)); // src_core - args.push_back(1); // direction: vertical - args.push_back(IntImm(DataType::Int(32), i)); // mask: current row only + args.push_back(1); // direction: vertical Stmt bcast_stmt = Evaluate(Call(DataType::Handle(), broadcast_(), args)); bcast_stmts.push_back(bcast_stmt); @@ -570,15 +597,333 @@ Stmt AllgatherNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return SeqStmt::Flatten(bcast_stmts); } -TIR_REGISTER_TL_TILE_OP(Allgather, comm_allgather) +TIR_REGISTER_TL_TILE_OP(AllgatherOp, comm_allgather) .set_num_inputs(4) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +AllreduceOp::AllreduceOp(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + node->src = args[0]; + node->dst = args[1]; + node->row_allgather = args[2]; + node->col_allgather = args[3]; + + node->type = Downcast(args[4]); + node->direction = Downcast(args[5])->value; + node->dim = Downcast(args[6]); + node->clear = Downcast(args[7]); + if (args.size() > 8) { + node->dst_copy = args[8]; + } + data_ = std::move(node); +} + +TileOperator AllreduceOpNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return AllreduceOp(op); +} + +// Not yet complete; it will be further refined later +LayoutMap AllreduceOpNode::ComputeLayout(const 
LayoutInferArgs &T, + InferLevel level, Buffer src, + Buffer dst, int dim) const { + if (level >= InferLevel::kStrict) + return {}; + + if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" && + T.layout_map.count(src)) { + auto src_layout = T.layout_map[src].as().value(); + + PrimExpr indice_rep_extent = src->shape[dim]; + PrimExpr src_rep_extent = src_layout->ReplicateExtent(); + PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent; + + Array fwd; + fwd.push_back(InputPlaceholder(0)); + for (int i = 0; i < static_cast(src->shape.size()); i++) { + if (i == dim) { + ; + } else if (i < dim) { + fwd.push_back(InputPlaceholder(i + 1)); + } else if (i > dim) { + fwd.push_back(InputPlaceholder(i - 1 + 1)); + } + } + auto thd = src_layout->ForwardThread( + fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent)); + + // Ensure the thread count is divisible by the replicate extent. + // Otherwise, we cannot infer a valid fragment<->fragment layout. + { + arith::Analyzer analyzer; + PrimExpr num_threads = T.thread_bounds->extent; + // Though the dest_buffer_rep_extent will be compressed at + // CondenseReplicateVar, we need to check the divisibility here to avoid + // the issue that the thread count is not divisible by the replicate + // extent. + if (!analyzer.CanProve(FloorMod(num_threads, dest_buffer_rep_extent) == + 0) && + !analyzer.CanProve(FloorMod(dest_buffer_rep_extent, num_threads) == + 0)) { + ICHECK(false) << "ReduceOp fragment layout inference failed: " + "num_threads % replicate_extent != 0. " + << "This mapping requires the block's thread count to be " + "divisible by the " + << "replicate extent. 
" + << "Try one of: (1) choose a thread block size divisible " + "by replicate_extent; " + << "(2) pick a different reduce dimension or adjust the " + "source fragment layout; " + << "Details: num_threads=" << num_threads + << ", replicate_extent=" << indice_rep_extent + << ", src=" << src << ", dst=" << dst; + } + } + + Fragment dst_layout = + Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt) + ->CondenseReplicateVar() + ->BindThreadRange(T.thread_bounds); + + if (!T.layout_map.count(dst)) + return {{dst, dst_layout}}; + else { + // Check if computed layout is compatible with existing: the existing one + // must strictly contains the computed layout + auto orig_dst_layout = + T.layout_map.Get(dst).value().as().value(); + ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim()); + Array indices; + indices.reserve(dst_layout->InputDim()); + arith::Analyzer inner_analyzer; + for (int i = 0; i < dst_layout->InputDim(); ++i) { + auto x = InputPlaceholder(i); + indices.push_back(x); + // should be literal - literal = 0, any analyzer will work + ICHECK(is_zero(inner_analyzer.Simplify( + dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i]))); + inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i])); + } + + ICHECK(as_const_int(dst_layout->ReplicateExtent())); + ICHECK(as_const_int(src_layout->ReplicateExtent())); + auto dst_rep = *as_const_int(dst_layout->ReplicateExtent()); + auto src_rep = *as_const_int(src_layout->ReplicateExtent()); + if (dst_rep < src_rep || + !ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices, + inner_analyzer)) { + std::ostringstream oss; + oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. 
" + << src << "\nLHS = " << src_layout->DebugOutput() + << "\nRHS = " << orig_dst_layout->DebugOutput() + << "\nYou may need to use a shared memory to transform the " + "layout"; + throw LayoutConflictException(oss.str()); + } + + if (dst_rep > src_rep) { + return {{dst, dst_layout}}; + } + } + } + return {}; +} + +LayoutMap AllreduceOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + LayoutMap lm; + + Array dst_layout_args; + dst_layout_args.push_back(src); + dst_layout_args.push_back(dst); + dst_layout_args.push_back(type); + dst_layout_args.push_back(dim); + dst_layout_args.push_back(clear); + ReduceOp dst_layout_op = ReduceOp(dst_layout_args); + LayoutMap dst_layout_map = dst_layout_op->InferLayout(T, InferLevel::kFree); + for (const auto &kv : dst_layout_map) { + lm.Set(kv.first, kv.second); + } + + if (dst_copy.defined()) { + Array dst_copy_layout_args; + dst_copy_layout_args.push_back(src); + dst_copy_layout_args.push_back(dst_copy); + dst_copy_layout_args.push_back(type); + dst_copy_layout_args.push_back(dim); + dst_copy_layout_args.push_back(clear); + ReduceOp dst_copy_layout_op = ReduceOp(dst_copy_layout_args); + LayoutMap dst_copy_layout_map = + dst_copy_layout_op->InferLayout(T, InferLevel::kFree); + for (const auto &kv : dst_copy_layout_map) { + lm.Set(kv.first, kv.second); + } + } + + Buffer row_allgather_buffer = NormalizeToBufferRegion(row_allgather)->buffer; + LayoutMap row_allgather_layout = + ComputeLayout(T, InferLevel::kFree, NormalizeToBufferRegion(src)->buffer, + row_allgather_buffer, dim->value); + for (const auto &kv : row_allgather_layout) { + lm.Set(kv.first, kv.second); + } + + Buffer col_allgather_buffer = NormalizeToBufferRegion(col_allgather)->buffer; + LayoutMap col_allgather_layout = + ComputeLayout(T, InferLevel::kFree, NormalizeToBufferRegion(src)->buffer, + col_allgather_buffer, dim->value); + for (const auto &kv : col_allgather_layout) { + lm.Set(kv.first, kv.second); + } + + return lm; +} + +Stmt 
AllreduceOpNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { + Target target = T.target; + ICHECK(TargetIsSunmmio(target)) << "Allreduce only supports SUNMMIO targets."; + int mesh_x = get_target_mesh(target, 0); + int mesh_y = get_target_mesh(target, 1); + + Array stmts; + + if (clear.as().value() == true) { + // Local reduce to dst + Array local_reduce_args; + local_reduce_args.push_back(src); + local_reduce_args.push_back(dst); + local_reduce_args.push_back(type); + local_reduce_args.push_back(dim); + local_reduce_args.push_back(IntImm(DataType::Int(32), 1)); // clear = true + ReduceOp local_reduce_op = ReduceOp(local_reduce_args); + Stmt local_reduce_stmt = local_reduce_op->Lower(T, analyzer); + stmts.push_back(local_reduce_stmt); + + if (direction == 0 or direction == 2) { // row-wise + // Allgather dst in rows to row_allgather + Array row_allgather_args; + row_allgather_args.push_back(dst); + row_allgather_args.push_back(row_allgather); + row_allgather_args.push_back( + IntImm(DataType::Int(32), 0)); // direction = horizontal + row_allgather_args.push_back(IntImm(DataType::Int(32), -1)); // size + AllgatherOp row_allgather_op = AllgatherOp(row_allgather_args); + Stmt row_allgather_stmt = row_allgather_op->Lower(T, analyzer); + stmts.push_back(row_allgather_stmt); + + // Local reduce from row_allgather to dst + Array row_reduce_args; + row_reduce_args.push_back(row_allgather); + row_reduce_args.push_back(dst); + row_reduce_args.push_back(type); + row_reduce_args.push_back(IntImm(DataType::Int(32), 0)); // dim + row_reduce_args.push_back(IntImm(DataType::Int(32), 1)); // clear = true + ReduceOp row_reduce_op = ReduceOp(row_reduce_args); + Stmt row_reduce_stmt = row_reduce_op->Lower(T, analyzer); + stmts.push_back(row_reduce_stmt); + } + + if (direction == 1 or direction == 2) { // column-wise + // Allgather dst in columns to col_allgather + Array col_allgather_args; + col_allgather_args.push_back(dst); + 
col_allgather_args.push_back(col_allgather); + col_allgather_args.push_back( + IntImm(DataType::Int(32), 1)); // direction = vertical + col_allgather_args.push_back(IntImm(DataType::Int(32), -1)); // size + AllgatherOp col_allgather_op = AllgatherOp(col_allgather_args); + Stmt col_allgather_stmt = col_allgather_op->Lower(T, analyzer); + stmts.push_back(col_allgather_stmt); + + // Local reduce from col_allgather to dst + Array col_reduce_args; + col_reduce_args.push_back(col_allgather); + col_reduce_args.push_back(dst); + col_reduce_args.push_back(type); + col_reduce_args.push_back(IntImm(DataType::Int(32), 0)); // dim + col_reduce_args.push_back(IntImm(DataType::Int(32), 1)); // clear = true + ReduceOp col_reduce_op = ReduceOp(col_reduce_args); + Stmt col_reduce_stmt = col_reduce_op->Lower(T, analyzer); + stmts.push_back(col_reduce_stmt); + } + } else { + // Local reduce to dst_copy + Array local_reduce_args; + local_reduce_args.push_back(src); + local_reduce_args.push_back(dst_copy); + local_reduce_args.push_back(type); + local_reduce_args.push_back(dim); + local_reduce_args.push_back(IntImm(DataType::Int(32), 1)); // clear = true + ReduceOp local_reduce_op = ReduceOp(local_reduce_args); + Stmt local_reduce_stmt = local_reduce_op->Lower(T, analyzer); + stmts.push_back(local_reduce_stmt); + + if (direction == 0 or direction == 2) { // row-wise + // Allgather dst in rows to row_allgather + Array row_allgather_args; + row_allgather_args.push_back(dst_copy); + row_allgather_args.push_back(row_allgather); + row_allgather_args.push_back( + IntImm(DataType::Int(32), 0)); // direction = horizontal + row_allgather_args.push_back(IntImm(DataType::Int(32), -1)); // size + AllgatherOp row_allgather_op = AllgatherOp(row_allgather_args); + Stmt row_allgather_stmt = row_allgather_op->Lower(T, analyzer); + stmts.push_back(row_allgather_stmt); + + // Local reduce from row_allgather to dst + Array row_reduce_args; + row_reduce_args.push_back(row_allgather); + 
row_reduce_args.push_back(direction == 0 ? dst : dst_copy); + row_reduce_args.push_back(type); + row_reduce_args.push_back(IntImm(DataType::Int(32), 0)); // dim + row_reduce_args.push_back(IntImm( + DataType::Int(32), + direction == 0 ? 0 : 1)); // clear = direction == 0 ? false : true + ReduceOp row_reduce_op = ReduceOp(row_reduce_args); + Stmt row_reduce_stmt = row_reduce_op->Lower(T, analyzer); + stmts.push_back(row_reduce_stmt); + } + + if (direction == 1 or direction == 2) { // column-wise + // Allgather dst in columns to col_allgather + Array col_allgather_args; + col_allgather_args.push_back(dst_copy); + col_allgather_args.push_back(col_allgather); + col_allgather_args.push_back( + IntImm(DataType::Int(32), 1)); // direction = vertical + col_allgather_args.push_back(IntImm(DataType::Int(32), -1)); // size + AllgatherOp col_allgather_op = AllgatherOp(col_allgather_args); + Stmt col_allgather_stmt = col_allgather_op->Lower(T, analyzer); + stmts.push_back(col_allgather_stmt); + + // Local reduce from col_allgather to dst + Array col_reduce_args; + col_reduce_args.push_back(col_allgather); + col_reduce_args.push_back(dst); + col_reduce_args.push_back(type); + col_reduce_args.push_back(IntImm(DataType::Int(32), 0)); // dim + col_reduce_args.push_back(IntImm(DataType::Int(32), 0)); // clear = false + ReduceOp col_reduce_op = ReduceOp(col_reduce_args); + Stmt col_reduce_stmt = col_reduce_op->Lower(T, analyzer); + stmts.push_back(col_reduce_stmt); + } + } + + return SeqStmt::Flatten(stmts); +} + +TIR_REGISTER_TL_TILE_OP(AllreduceOp, comm_allreduce) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TVM_FFI_STATIC_INIT_BLOCK() { - PutNode::RegisterReflection(); - BroadcastNode::RegisterReflection(); - AllgatherNode::RegisterReflection(); + PutOpNode::RegisterReflection(); + BroadcastOpNode::RegisterReflection(); + AllgatherOpNode::RegisterReflection(); + AllreduceOpNode::RegisterReflection(); } } // namespace tl diff --git 
a/src/op/comm.h b/src/op/comm.h index e0b4ffe43..b79eaa597 100644 --- a/src/op/comm.h +++ b/src/op/comm.h @@ -20,7 +20,7 @@ TVM_DLL const Op &broadcast_(); using namespace tir; -class BroadcastNode : public TileOperatorNode { +class BroadcastOpNode : public TileOperatorNode { public: Buffer src, dst; Array src_range, dst_range; @@ -29,40 +29,38 @@ class BroadcastNode : public TileOperatorNode { IntImm dst_offset; IntImm src_core; int direction; - // Array group; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_broadcast", BroadcastNode, + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_broadcast", BroadcastOpNode, TileOperatorNode); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("src", &BroadcastNode::src) - .def_ro("dst", &BroadcastNode::dst) - .def_ro("src_range", &BroadcastNode::src_range) - .def_ro("dst_range", &BroadcastNode::dst_range) - .def_ro("src_core", &BroadcastNode::src_core) - .def_ro("direction", &BroadcastNode::direction) - .def_ro("size", &BroadcastNode::size) - .def_ro("dst_offset", &BroadcastNode::dst_offset); - // .def_ro("group", &BroadcastNode::group); + refl::ObjectDef() + .def_ro("src", &BroadcastOpNode::src) + .def_ro("dst", &BroadcastOpNode::dst) + .def_ro("src_range", &BroadcastOpNode::src_range) + .def_ro("dst_range", &BroadcastOpNode::dst_range) + .def_ro("src_core", &BroadcastOpNode::src_core) + .def_ro("direction", &BroadcastOpNode::direction) + .def_ro("size", &BroadcastOpNode::size) + .def_ro("dst_offset", &BroadcastOpNode::dst_offset); } - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + TileOperator Clone() const; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; - TileOperator Clone() const; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; }; -class Broadcast : public TileOperator { +class BroadcastOp : public TileOperator { public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Broadcast, 
TileOperator, - BroadcastNode); - TVM_DLL Broadcast(Array args); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BroadcastOp, TileOperator, + BroadcastOpNode); + TVM_DLL BroadcastOp(Array args); static const Op &Get(); }; -class PutNode : public TileOperatorNode { +class PutOpNode : public TileOperatorNode { public: Buffer src, dst; Array src_range, dst_range; @@ -70,63 +68,107 @@ class PutNode : public TileOperatorNode { IntImm src_core, dst_core; IntImm size; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_put", PutNode, TileOperatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_put", PutOpNode, TileOperatorNode); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("src", &PutNode::src) - .def_ro("dst", &PutNode::dst) - .def_ro("src_range", &PutNode::src_range) - .def_ro("dst_range", &PutNode::dst_range) - .def_ro("src_core", &PutNode::src_core) - .def_ro("dst_core", &PutNode::dst_core) - .def_ro("size", &PutNode::size); + refl::ObjectDef() + .def_ro("src", &PutOpNode::src) + .def_ro("dst", &PutOpNode::dst) + .def_ro("src_range", &PutOpNode::src_range) + .def_ro("dst_range", &PutOpNode::dst_range) + .def_ro("src_core", &PutOpNode::src_core) + .def_ro("dst_core", &PutOpNode::dst_core) + .def_ro("size", &PutOpNode::size); } - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + TileOperator Clone() const; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; - TileOperator Clone() const; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; }; -class Put : public TileOperator { +class PutOp : public TileOperator { public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Put, TileOperator, PutNode); - TVM_DLL Put(Array args); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PutOp, TileOperator, PutOpNode); + TVM_DLL PutOp(Array args); static const Op &Get(); }; -class AllgatherNode : public TileOperatorNode { +class AllgatherOpNode : public 
TileOperatorNode { public: PrimExpr send, recv; int direction; IntImm size; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_allgather", AllgatherNode, + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_allgather", AllgatherOpNode, TileOperatorNode); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("send", &AllgatherNode::send) - .def_ro("recv", &AllgatherNode::recv) - .def_ro("direction", &AllgatherNode::direction) - .def_ro("size", &AllgatherNode::size); + refl::ObjectDef() + .def_ro("send", &AllgatherOpNode::send) + .def_ro("recv", &AllgatherOpNode::recv) + .def_ro("direction", &AllgatherOpNode::direction) + .def_ro("size", &AllgatherOpNode::size); } - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; - Layout ComputeLinearLayout(const Buffer &shared_tensor) const; + TileOperator Clone() const; + LayoutMap ComputeLayout(const LayoutInferArgs &T, InferLevel level, + Buffer src, Buffer dst) const; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; +}; + +class AllgatherOp : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AllgatherOp, TileOperator, + AllgatherOpNode); + TVM_DLL AllgatherOp(Array args); + static const Op &Get(); +}; + +class AllreduceOpNode : public TileOperatorNode { +public: + PrimExpr src, dst; + PrimExpr row_allgather, col_allgather; + PrimExpr dst_copy; + StringImm type; + int direction; + IntImm dim; + IntImm clear; + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_allreduce", AllreduceOpNode, + TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src", &AllreduceOpNode::src) + .def_ro("dst", &AllreduceOpNode::dst) + .def_ro("row_allgather", &AllreduceOpNode::row_allgather) + .def_ro("col_allgather", &AllreduceOpNode::col_allgather) + .def_ro("type", 
&AllreduceOpNode::type) + .def_ro("dim", &AllreduceOpNode::dim) + .def_ro("clear", &AllreduceOpNode::clear) + .def_ro("direction", &AllreduceOpNode::direction) + .def_ro("dst_copy", &AllreduceOpNode::dst_copy); + } + TileOperator Clone() const; + LayoutMap ComputeLayout(const LayoutInferArgs &T, InferLevel level, + Buffer src, Buffer dst, int dim) const; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; }; -class Allgather : public TileOperator { +class AllreduceOp : public TileOperator { public: - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Allgather, TileOperator, - AllgatherNode); - TVM_DLL Allgather(Array args); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AllreduceOp, TileOperator, + AllreduceOpNode); + TVM_DLL AllreduceOp(Array args); static const Op &Get(); }; diff --git a/testing/python/language/test_tilelang_language_comm.py b/testing/python/language/test_tilelang_language_comm.py index 4817c90a9..da8ce75f0 100644 --- a/testing/python/language/test_tilelang_language_comm.py +++ b/testing/python/language/test_tilelang_language_comm.py @@ -77,10 +77,10 @@ def main(A_handle: T.handle): for vec in T.vectorized(4): A_local[i, j * 4 + vec] = T.Cast("float32", A[by * 128 + i, bx * 128 + (j * 4 + vec)]) T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 6, 1) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 2, 0, 2) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 6, 0, 2) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), 
T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 10, 0, 2) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 14, 0, 2)""" + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 2, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 6, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 10, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 14, 0)""" @T.prim_func def main(A: T.Tensor((M, N), dtype),): @@ -191,22 +191,22 @@ def main(A_handle: T.handle): T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 212992, 16384, 2), 16384, 13, 0) T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 229376, 16384, 2), 16384, 14, 0) T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 245760, 16384, 2), 16384, 15, 0) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 0, 1, 0) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 
65536, 4, 1, 1) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 8, 1, 2) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 12, 1, 3) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 1, 1, 0) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 5, 1, 1) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 9, 1, 2) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 13, 1, 3) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 2, 1, 0) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 6, 1, 1) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 10, 1, 2) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 14, 1, 3) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), 
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 3, 1, 0) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 7, 1, 1) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 11, 1, 2) - T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 15, 1, 3)""" + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 0, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 4, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 8, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 12, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 1, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 5, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 9, 1) + 
T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 13, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 2, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 6, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 10, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 14, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 3, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 7, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 11, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 15, 1)""" @T.prim_func def main(A: T.Tensor((M, N), dtype),): @@ -225,8 +225,115 @@ def main(A: T.Tensor((M, N), dtype),): assert mod.script() == func_str, "The generated script does not match the expected output." 
+@pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ + (1024 * 128, 1024 * 128, 1024, 1024, "float16", "float"), +]) +def test_comm_all_reduce_lower(M, N, block_M, block_N, dtype="float16", accum_dtype="float"): + func_str = """# from tvm.script import ir as I +# from tvm.script import tir as T + +@I.ir_module +class Module: + @T.prim_func + def main(A_handle: T.handle): + T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mattr": ["device_mesh_nrow_4", "device_mesh_ncol_4"], "mcpu": "sunmmio-a4e", "tag": ""})}) + A = T.match_buffer(A_handle, (131072, 131072), "float16", strides=(131072, 1)) + with T.block("root"): + T.reads() + T.writes() + A_local = T.Buffer((8192,), scope="local") + T.block_attr({"layout_map": {A_local: metadata["tl.Fragment"][0]}}) + bx = T.launch_thread("blockIdx.x", 128) + by = T.launch_thread("blockIdx.y", 128) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + with T.block("tilelang_root"): + T.reads(A[by * 1024, bx * 1024]) + T.writes() + T.block_attr({"layout_map": {A_local: metadata["tl.Fragment"][0]}}) + A_local = T.alloc_buffer((8192,), data=A_local.data, scope="local") + E_local = T.alloc_buffer((1024,), scope="local") + buffer = T.alloc_buffer((32,), scope="local") + buffer_1 = T.alloc_buffer((32,), scope="local") + buffer_2 = T.alloc_buffer((1024,), scope="local") + workspace = T.alloc_buffer((128,), scope="shared.dyn") + for i in T.parallel(1024): + for j in T.parallel(256): + for vec in T.vectorized(4): + A_local[i * 8 + (j * 4 + vec) // 512 * 4 + (j * 4 + vec) % 4] = T.Cast("float32", A[by * 1024 + i, bx * 1024 + (j * 4 + vec)]) + for i in T.unroll(1024, annotations={"pragma_unroll_explicit": T.bool(False)}): + buffer_2[i] = T.float32(0.0) + for rv in T.unroll(8, annotations={"pragma_unroll_explicit": T.bool(False)}): + buffer_2[i] = buffer_2[i] + A_local[i * 8 + rv % 2 * 4 + rv // 2] + buffer_2[i] = 
T.call_extern("float32", "tl::AllReduce::run", buffer_2[i], T.tvm_access_ptr(T.type_annotation("float32"), workspace.data, 0, 128, 2)) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 0, 1024, 2), 1024, 0, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 1024, 1024, 2), 1024, 1, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 2048, 1024, 2), 1024, 2, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 3072, 1024, 2), 1024, 3, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 0, 1024, 2), 1024, 4, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 1024, 1024, 2), 1024, 5, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 2048, 1024, 2), 1024, 6, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 3072, 1024, 2), 1024, 7, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 0, 1024, 2), 1024, 8, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 1024, 1024, 2), 1024, 9, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), 
T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 2048, 1024, 2), 1024, 10, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 3072, 1024, 2), 1024, 11, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 0, 1024, 2), 1024, 12, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 1024, 1024, 2), 1024, 13, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 2048, 1024, 2), 1024, 14, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 3072, 1024, 2), 1024, 15, 0) + for i in T.unroll(1024, annotations={"pragma_unroll_explicit": T.bool(False)}): + buffer_2[i] = T.float32(0.0) + for rv in T.unroll(4, annotations={"pragma_unroll_explicit": T.bool(False)}): + buffer_2[i] = buffer_2[i] + buffer[rv * 8 + i // 512 * 4 + i % 4] + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 0, 1024, 2), 1024, 0, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 1024, 1024, 2), 1024, 4, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 2048, 1024, 2), 1024, 8, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 3072, 1024, 2), 1024, 12, 1) + 
T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 0, 1024, 2), 1024, 1, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 1024, 1024, 2), 1024, 5, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 2048, 1024, 2), 1024, 9, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 3072, 1024, 2), 1024, 13, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 0, 1024, 2), 1024, 2, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 1024, 1024, 2), 1024, 6, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 2048, 1024, 2), 1024, 10, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 3072, 1024, 2), 1024, 14, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 0, 1024, 2), 1024, 3, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 1024, 1024, 2), 1024, 7, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 2048, 1024, 2), 1024, 11, 1) + 
T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 3072, 1024, 2), 1024, 15, 1) + E_local_clear = T.allocate([1024], "float32", "local") + for i in T.unroll(1024, annotations={"pragma_unroll_explicit": T.bool(False)}): + E_local_clear_1 = T.Buffer((1024,), data=E_local_clear, scope="local") + E_local_clear_1[i] = T.float32(0.0) + for rv in T.unroll(4, annotations={"pragma_unroll_explicit": T.bool(False)}): + E_local_clear_1[i] = E_local_clear_1[i] + buffer_1[rv * 8 + i // 512 * 4 + i % 4] + E_local[i] = E_local[i] + E_local_clear_1[i] + +# Metadata omitted. Use show_meta=True in script() method to show it.""" + + @T.prim_func + def main(A: T.Tensor((M, N), dtype),): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_local = T.alloc_fragment([block_M, block_N], accum_dtype) + E_local = T.alloc_fragment([block_M], accum_dtype) + T.copy(A[by * block_M, bx * block_N], A_local) + + T.comm.all_reduce(A_local, E_local, "sum", "all", dim=-1, clear=False) + + mod = tvm.IRModule({'main': main}) + target = determine_target("Sunmmio", return_object=True) + with tvm.target.Target(target): + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tilelang.transform.LayoutInference()(mod) + mod = tilelang.transform.LowerTileOp()(mod) + assert mod.script() == func_str, "The generated script does not match the expected output." 
+ + if __name__ == "__main__": test_comm_python_api(1024, 1024, 128, 128, "float16", "float") test_comm_broadcast_lower(1024, 1024, 128, 128, "float16", "float") test_comm_put_lower(1024, 1024, 128, 128, "float16", "float") test_comm_all_gather_lower(1024, 1024, 128, 128, "float16", "float") + test_comm_all_reduce_lower(1024 * 128, 1024 * 128, 1024, 1024, "float16", "float") diff --git a/tilelang/language/comm.py b/tilelang/language/comm.py index 72d9329cf..bf4d5192a 100644 --- a/tilelang/language/comm.py +++ b/tilelang/language/comm.py @@ -11,32 +11,28 @@ from tvm import tir import tilelang.language as T -from tilelang.utils.language import to_buffer_region +from tilelang.utils.language import ( + to_buffer_region,) + +from tilelang.carver.arch.driver import get_sunmmio_device_mesh_config DIRECTION_MAP = {"horizontal": 0, "h": 0, "vertical": 1, "v": 1, "all": 2, "a": 2} +REDUCE_TYPE_LIST = ( + "sum", + "abssum", + "max", + "min", + "absmax", + "bitand", + "bitor", + "bitxor", +) -def get_target_mesh_shape(target: str = "auto") -> dict[str, int]: - """Get the shape of the target mesh as a dictionary with 'x' and 'y' keys. - Args: - target: The target mesh type. Supported values are - 'sunmmio-a4e', 'sunmmio-a4e-lite', and 'auto'. If 'auto' is specified, - the function defaults to 'sunmmio-a4e'. - Returns: - A dictionary with integer keys 'x' and 'y' representing - the 2D mesh size in each dimension. - Raises: - ValueError: If an unknown target is specified. 
- """ - if target == "auto": - target = "sunmmio-a4e" - - if target == "sunmmio-a4e": - return {"x": 4, "y": 4} - elif target == "sunmmio-a4e-lite": - return {"x": 2, "y": 4} - else: - raise ValueError(f"Unknown target: {target}") +def get_target_mesh_shape() -> dict[str, int]: + """Get the target mesh shape as a dictionary with 'x' and 'y' keys.""" + nrow, ncol = get_sunmmio_device_mesh_config() + return {"x": nrow, "y": ncol} def core_tuple_to_id(core_id: tuple[int, int]) -> int: @@ -55,9 +51,9 @@ def core_tuple_to_id(core_id: tuple[int, int]) -> int: Notes ----- The conversion uses the current target mesh shape obtained via - get_target_mesh_shape("auto"). + get_target_mesh_shape(). """ - mesh_shape = get_target_mesh_shape("auto") + mesh_shape = get_target_mesh_shape() row, col = core_id assert (0 <= row < mesh_shape["x"]), f"Row {row} out of bounds for mesh shape {mesh_shape}." assert (0 <= col < mesh_shape["y"]), f"Col {col} out of bounds for mesh shape {mesh_shape}." @@ -81,9 +77,9 @@ def core_id_to_tuple(core_id: tir.Call) -> tuple[int, int]: Notes ----- The conversion uses the current target mesh shape obtained via - get_target_mesh_shape("auto"). + get_target_mesh_shape(). """ - mesh_shape = get_target_mesh_shape("auto") + mesh_shape = get_target_mesh_shape() core_id_value = core_id row = core_id_value // mesh_shape["y"] col = core_id_value % mesh_shape["y"] @@ -109,13 +105,9 @@ def CoreId(core_id: int | tuple[int, int]): AssertionError, ValueError If the provided coordinates are out of bounds or the type is invalid. """ - mesh_shape = get_target_mesh_shape("auto") + mesh_shape = get_target_mesh_shape() if isinstance(core_id, tuple): - row, col = core_id - assert (0 <= row < mesh_shape["x"]), f"Row {row} out of bounds for mesh shape {mesh_shape}" - assert (0 <= col < mesh_shape["y"]), f"Col {col} out of bounds for mesh shape {mesh_shape}" - # Convert 2D coordinates into a linear core id. 
- core_id_value = row * mesh_shape["x"] + col + core_id_value = core_tuple_to_id(core_id) elif isinstance(core_id, int): core_id_value = core_id assert (0 <= core_id_value < mesh_shape["x"] * mesh_shape["y"] @@ -181,7 +173,7 @@ def broadcast( src.shape[i] == dst.shape[i] or src.shape[i] == 1 or dst.shape[i] == 1 ), f"Source buffer shape and destination buffer shape must match for broadcast. Got {src.shape} vs {dst.shape}." - mesh_shape = get_target_mesh_shape("auto") + mesh_shape = get_target_mesh_shape() assert (isinstance(src_core, tuple) and len(src_core) == 2), "src_core must be a tuple of (row, col)." assert (0 <= src_core[0] < mesh_shape["x"] @@ -193,7 +185,7 @@ def broadcast( for dim in src.shape: src_elements *= dim assert isinstance(size, int) and size >= -1, "size must be an integer >= -1." - assert size <= src_elements, f"size {size} exceeds source buffer size {src_elements}." + assert (size <= src_elements), f"size {size} exceeds source buffer size {src_elements}." assert direction.lower() in DIRECTION_MAP, f"Invalid direction string: {direction}" @@ -244,16 +236,16 @@ def put( """ assert ( src.dtype == dst.dtype - ), f"Source and destination buffer dtypes must match for broadcast. Got {src.dtype} vs {dst.dtype}." + ), f"Source and destination buffer dtypes must match for put. Got {src.dtype} vs {dst.dtype}." if len(src.shape) != len(dst.shape): raise ValueError( - "Source and destination buffer must have the same number of dimensions for broadcast.") + "Source and destination buffer must have the same number of dimensions for put.") for i in range(len(src.shape)): assert ( src.shape[i] == dst.shape[i] or src.shape[i] == 1 or dst.shape[i] == 1 - ), f"Source buffer shape and destination buffer shape must match for broadcast. Got {src.shape} vs {dst.shape}." + ), f"Source buffer shape and destination buffer shape must be compatible for put. Got {src.shape} vs {dst.shape}." 
- mesh_shape = get_target_mesh_shape("auto") + mesh_shape = get_target_mesh_shape() assert (isinstance(src_core, tuple) and len(src_core) == 2), "src_core must be a tuple of (row, col)." assert (0 <= src_core[0] < mesh_shape["x"] @@ -311,8 +303,8 @@ def all_gather( assert ( send_buffer.dtype == recv_buffer.dtype - ), f"Source and destination buffer dtypes must match for broadcast. Got {send_buffer.dtype} vs {recv_buffer.dtype}." - mesh_shape = get_target_mesh_shape("auto") + ), f"Source and destination buffer dtypes must match for all_gather. Got {send_buffer.dtype} vs {recv_buffer.dtype}." + mesh_shape = get_target_mesh_shape() recv_num = 1 if direction.lower() in ["horizontal", "h"]: @@ -336,16 +328,111 @@ def all_gather( send_buffer_region = to_buffer_region(send_buffer) recv_buffer_region = to_buffer_region(recv_buffer) - direction_map = {"horizontal": 0, "h": 0, "vertical": 1, "v": 1, "all": 2, "a": 2} args = ( send_buffer_region, recv_buffer_region, - direction_map[direction.lower()], + DIRECTION_MAP[direction.lower()], size, ) return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.comm_allgather"), *args) +def all_reduce( + buffer: T.Buffer, + out: T.Buffer, + reduce_type: str, + direction: Literal["horizontal", "h", "vertical", "v", "all", "a"], + dim: int = -1, + clear: bool = True, +): + """Perform an all-reduce operation on a buffer and store the result in an output buffer + by emitting the TIR intrinsic tl.tileop.comm_allreduce. + Parameters + ---------- + buffer : T.Buffer + Input buffer containing data to reduce. + out : T.Buffer + Output buffer to store the reduced result. + reduce_type : str + Type of reduction operation (e.g., "sum", "max", etc.). + direction : Literal["horizontal", "h", "vertical", "v", "all", "a"] + Direction of all-reduce: "horizontal" (or "h") for row-wise, "vertical" (or "v") for column-wise, + and "all" (or "a") for all cores. + dim : int + Dimension along which to perform the reduction. Default is -1 (last dimension). 
+ clear : bool + Whether to clear the output buffer before reduction. Default is True. + Returns + ------- + tir.Call + The TIR intrinsic call handle for `tl.tileop.comm_allreduce`. + Examples + -------- + >>> all_reduce(A_local, E_local, "sum", "all", dim=-1, clear=False) + """ + assert (isinstance(dim, int) and dim >= -1 and dim < len( + buffer.shape)), f"dim {dim} out of bounds for buffer with {len(buffer.shape)} dimensions." + if dim == -1: + dim = len(buffer.shape) - 1 + + expected_shapes = [ + buffer.shape[:dim] + buffer.shape[dim + 1:], + buffer.shape[:dim] + [1] + buffer.shape[dim + 1:], + ] + if list(out.shape) not in expected_shapes: + expected_shapes_str = " or ".join(map(str, expected_shapes)) + raise ValueError( + f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, " + f"output shape is {out.shape}, expected shapes are {expected_shapes_str}") + + reduce_type = reduce_type.lower() + assert (reduce_type in REDUCE_TYPE_LIST + ), f"Reduction op must be one of {REDUCE_TYPE_LIST}, but got {reduce_type}." + + assert direction.lower() in DIRECTION_MAP, f"Invalid direction string: {direction}" + assert clear in [True, False], "clear must be a boolean value." 
+ + mesh_shape = get_target_mesh_shape() + + # Create temporary buffers for row and column allgather results + row_allgather = T.alloc_fragment(list([mesh_shape["x"]] + out.shape), out.dtype) + col_allgather = T.alloc_fragment(list([mesh_shape["y"]] + out.shape), out.dtype) + + buffer_region = to_buffer_region(buffer) + out_region = to_buffer_region(out) + row_allgather_region = to_buffer_region(row_allgather) + col_allgather_region = to_buffer_region(col_allgather) + + args = ( + buffer_region, + out_region, + row_allgather_region, + col_allgather_region, + reduce_type, + DIRECTION_MAP[direction.lower()], + dim, + clear, + ) + + # If not clearing, allocate an output copy buffer to hold intermediate results + if not clear: + out_copy = T.alloc_fragment(list(out.shape), out.dtype) + out_copy_region = to_buffer_region(out_copy) + args = ( + buffer_region, + out_region, + row_allgather_region, + col_allgather_region, + reduce_type, + DIRECTION_MAP[direction.lower()], + dim, + clear, + out_copy_region, + ) + + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.comm_allreduce"), *args) + + def barrier(group: Iterable[tuple[int, int]] | None = None): """Insert a synchronization barrier among a group of cores. @@ -354,8 +441,6 @@ def barrier(group: Iterable[tuple[int, int]] | None = None): group : iterable of tuple[int, int] | None Optional set of core coordinates to synchronize. If omitted, the runtime's default participant set is used. - Optional set of core coordinates to synchronize. If omitted, the - runtime's default participant set is used. 
Returns ------- From 65fda750e157ee10e5c53a0dfbd3fc23a189e769 Mon Sep 17 00:00:00 2001 From: xiaoyao-NKU <18946090101@163.com> Date: Fri, 30 Jan 2026 09:47:34 +0800 Subject: [PATCH 5/5] Some changes about all_reduce --- src/op/comm.cc | 4 ++++ testing/python/language/test_tilelang_language_comm.py | 10 +++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/op/comm.cc b/src/op/comm.cc index 135c44c0d..82ec32872 100644 --- a/src/op/comm.cc +++ b/src/op/comm.cc @@ -787,6 +787,10 @@ Stmt AllreduceOpNode::Lower(const LowerArgs &T, int mesh_x = get_target_mesh(target, 0); int mesh_y = get_target_mesh(target, 1); + ICHECK(direction == 0 || direction == 1 || direction == 2) + << "Invalid allreduce direction " << direction + << ", must be 0 (row-wise) or 1 (column-wise) or 2 (all)."; + Array stmts; if (clear.as().value() == true) { diff --git a/testing/python/language/test_tilelang_language_comm.py b/testing/python/language/test_tilelang_language_comm.py index da8ce75f0..fc63996bf 100644 --- a/testing/python/language/test_tilelang_language_comm.py +++ b/testing/python/language/test_tilelang_language_comm.py @@ -10,7 +10,7 @@ @pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ (1024, 1024, 128, 128, "float16", "float"), ]) -def test_comm_python_api(M, N, block_M, block_N, dtype="float16", accum_dtype="float"): +def test_comm_python_api(M, N, block_M, block_N, dtype, accum_dtype): func_str = """# from tvm.script import tir as T @T.prim_func @@ -51,7 +51,7 @@ def main(A: T.Tensor((M, N), dtype),): @pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ (1024, 1024, 128, 128, "float16", "float"), ]) -def test_comm_broadcast_lower(M, N, block_M, block_N, dtype="float16", accum_dtype="float"): +def test_comm_broadcast_lower(M, N, block_M, block_N, dtype, accum_dtype): func_str = """# from tvm.script import ir as I # from tvm.script import tir as T @@ -102,7 +102,7 @@ def main(A: T.Tensor((M, N), dtype),): 
@pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ (1024, 1024, 128, 128, "float16", "float"), ]) -def test_comm_put_lower(M, N, block_M, block_N, dtype="float16", accum_dtype="float"): +def test_comm_put_lower(M, N, block_M, block_N, dtype, accum_dtype): func_str = """# from tvm.script import ir as I # from tvm.script import tir as T @@ -150,7 +150,7 @@ def main(A: T.Tensor((M, N), dtype),): @pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ (1024, 1024, 128, 128, "float16", "float"), ]) -def test_comm_all_gather_lower(M, N, block_M, block_N, dtype="float16", accum_dtype="float"): +def test_comm_all_gather_lower(M, N, block_M, block_N, dtype, accum_dtype): func_str = """# from tvm.script import ir as I # from tvm.script import tir as T @@ -228,7 +228,7 @@ def main(A: T.Tensor((M, N), dtype),): @pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ (1024 * 128, 1024 * 128, 1024, 1024, "float16", "float"), ]) -def test_comm_all_reduce_lower(M, N, block_M, block_N, dtype="float16", accum_dtype="float"): +def test_comm_all_reduce_lower(M, N, block_M, block_N, dtype, accum_dtype): func_str = """# from tvm.script import ir as I # from tvm.script import tir as T