diff --git a/src/op/comm.cc b/src/op/comm.cc new file mode 100644 index 000000000..82ec32872 --- /dev/null +++ b/src/op/comm.cc @@ -0,0 +1,934 @@ +/*! + * \file tl/op/comm.cc + * \brief Implementation of Inter-core Communication Operators + */ + +#include "comm.h" + +#include +#include +#include + +#include "../target/utils.h" +#include "copy.h" +#include "reduce.h" +#include "utils.h" + +namespace tvm { +namespace tl { + +#define TIR_DEFINE_TL_BUILTIN(OpName) \ + const Op &OpName() { \ + static const Op &op = Op::Get("tl." #OpName); \ + return op; \ + } \ + TVM_REGISTER_OP("tl." #OpName) \ + .set_attr("TScriptPrinterName", #OpName) +TIR_DEFINE_TL_BUILTIN(comm_barrier) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(comm_fence) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(CoreId).set_num_inputs(1).set_attr( + "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(comm_current_core) + .set_num_inputs(0) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(comm_is_current_core) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(broadcast_) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); +// src_buffer, dst_buffer, size(IntImm), src_core(IntImm) +// direction(0: horizontal, 1: vertical), +// *mask(optional: IntImm list of core ids to exclude) + +using namespace tir; + +BroadcastOp::BroadcastOp(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + node->src_expr = args[0]; + node->dst_expr = args[1]; + Array rgs[2]; + Buffer bf[2]; + for (int i = 0; i < 2; i++) { + auto region = NormalizeToBufferRegion(args[i]); + + rgs[i] = region->region; + bf[i] = region->buffer; + } + std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); + std::tie(node->src_range, node->dst_range) 
= std::tie(rgs[0], rgs[1]); + node->size = Downcast(args[2]); + node->dst_offset = Downcast(args[3]); + node->src_core = Downcast(args[4]); + node->direction = Downcast(args[5])->value; + data_ = std::move(node); +} + +TileOperator BroadcastOpNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return BroadcastOp(op); +} + +LayoutMap BroadcastOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + Array args; + args.push_back(src_expr); + args.push_back(dst_expr); + Copy copy_op = Copy(args); + LayoutMap out_layout = copy_op->InferLayout(T, level); + return out_layout; +} + +int get_target_mesh(Target target, int axis) { + auto mattr = target->GetAttr>("mattr").value(); + int x = 0; + std::string axis_str; + if (axis == 0) { + axis_str = "device_mesh_ncol_"; + } else if (axis == 1) { + axis_str = "device_mesh_nrow_"; + } else { + LOG(FATAL) << "Invalid axis " << axis << " for getting mesh dimension."; + } + for (size_t i = 0; i < mattr.size(); i++) { + std::string m = mattr[i]; + if (m.find(axis_str) != std::string::npos) { + std::string s = m.substr(m.find_last_of('_') + 1); + ; + try { + x = std::stoi(s); + } catch (const std::invalid_argument &e) { + x = -1; + } catch (const std::out_of_range &e) { + x = -1; + } + } + } + ICHECK(x != 0) << axis_str << " not found."; + ICHECK(x > 0) << "Invalid " << axis_str; + return x; +} + +Stmt BroadcastOpNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { + Target target = T.target; + ICHECK(TargetIsSunmmio(target)) << "Broadcast only supports SUNMMIO targets."; + int mesh_x = get_target_mesh(target, 0); + int mesh_y = get_target_mesh(target, 1); + + // check for valid core id + ICHECK(src_core->value >= 0 and src_core->value < mesh_x * mesh_y) + << "Source core id " << src_core->value << " out of range [0, " + << mesh_x * mesh_y << ")"; + + // check for src and dst buffer sizes + PrimExpr src_elements = 1; + for (size_t i = 0; i < src_range.size(); i++) { + src_elements *= 
src_range[i]->extent; + } + src_elements = analyzer->Simplify(src_elements); + PrimExpr dst_elements = 1; + for (size_t i = 0; i < dst_range.size(); i++) { + dst_elements *= dst_range[i]->extent; + } + dst_elements = analyzer->Simplify(dst_elements); + ICHECK(Downcast(src_elements)->value <= + Downcast(dst_elements)->value) + << "Source buffer size larger than destination buffer size: " + << src_elements << " vs " << dst_elements; + ICHECK(size->value <= Downcast(src_elements)->value) + << "Broadcast size larger than data size: " << size->value << " vs " + << Downcast(src_elements)->value; + + // check for size and dst_offset + PrimExpr broadcast_elements; + if (size->value < 0) { + broadcast_elements = src_elements; + } else { + broadcast_elements = size; + } + ICHECK((Downcast(broadcast_elements)->value) <= + Downcast(src_elements)->value) + << "Broadcast size Larger than source buffer size: " + << (Downcast(broadcast_elements)->value) << " vs " + << Downcast(src_elements)->value; + ICHECK((Downcast(broadcast_elements)->value + dst_offset->value) <= + Downcast(dst_elements)->value) + << "Broadcast size + dst_offset larger than destination buffer size: " + << (Downcast(broadcast_elements)->value + dst_offset->value) + << " vs " << Downcast(dst_elements)->value; + + // check for valid direction + if (direction != 0 and direction != 1 and direction != 2) { + LOG(FATAL) << "Invalid broadcast direction " << direction + << ", must be 0 (horizontal) or 1 (vertical) or 2 (all)."; + } + + // all checks passed, generate the call + PrimExpr src_addr = src.access_ptr(1, DataType::Handle(), 1, 0, src_elements); + PrimExpr dst_addr = + dst.access_ptr(2, DataType::Handle(), 1, + Downcast(dst_offset->value), src_elements); + int src_core_y = src_core->value % mesh_y; + + if (direction == 0 or direction == 1) { + // 1D broadcast + Array args; + args.push_back(src_addr); + args.push_back(dst_addr); + args.push_back(Downcast(broadcast_elements)); + args.push_back(src_core); + 
args.push_back(direction); + Stmt broadcast = Evaluate(Call(DataType::Handle(), broadcast_(), args)); + return broadcast; + } else { + // 2D broadcast + Array seq; + // vertical broadcast + Array args; + args.push_back(src_addr); + args.push_back(dst_addr); + args.push_back(Downcast(broadcast_elements)); + args.push_back(src_core); + args.push_back(1); // direction: vertical + Stmt broadcast = Evaluate(Call(DataType::Handle(), broadcast_(), args)); + seq.push_back(broadcast); + // horizontal broadcast + for (int i = 0; i < mesh_x; i++) { + Array args; + args.push_back(dst.access_ptr(1, DataType::Handle(), 1, 0, dst_elements)); + args.push_back(dst.access_ptr(2, DataType::Handle(), 1, 0, dst_elements)); + args.push_back(Downcast(broadcast_elements)); + args.push_back(int(i * mesh_y) + src_core_y); + args.push_back(0); // direction: horizontal + Stmt broadcast = Evaluate(Call(DataType::Handle(), broadcast_(), args)); + seq.push_back(broadcast); + } + return SeqStmt::Flatten(seq); + } +} + +TIR_REGISTER_TL_TILE_OP(BroadcastOp, comm_broadcast) + .set_num_inputs(6) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +PutOp::PutOp(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + node->src_expr = args[0]; + node->dst_expr = args[1]; + Array rgs[2]; + Buffer bf[2]; + for (int i = 0; i < 2; i++) { + auto region = NormalizeToBufferRegion(args[i]); + rgs[i] = region->region; + bf[i] = region->buffer; + } + std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); + std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); + node->size = Downcast(args[2]); + node->src_core = Downcast(args[3]); + node->dst_core = Downcast(args[4]); + data_ = std::move(node); +} + +TileOperator PutOpNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return PutOp(op); +} + +LayoutMap PutOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + Array args; + args.push_back(src_expr); + args.push_back(dst_expr); + Copy 
copy_op = Copy(args); + LayoutMap out_layout = copy_op->InferLayout(T, level); + return out_layout; +} + +Stmt PutOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + Target target = T.target; + ICHECK(TargetIsSunmmio(target)) << "Put only supports SUNMMIO targets."; + int mesh_x = get_target_mesh(target, 0); + int mesh_y = get_target_mesh(target, 1); + + // check for valid core id + ICHECK(src_core->value >= 0 and src_core->value < mesh_x * mesh_y) + << "Source core id " << src_core->value << " out of range [0, " + << mesh_x * mesh_y << ")"; + ICHECK(dst_core->value >= 0 and dst_core->value < mesh_x * mesh_y) + << "Destination core id " << dst_core->value << " out of range [0, " + << mesh_x * mesh_y << ")"; + + // check for src and dst buffer sizes + PrimExpr src_elements = 1; + for (size_t i = 0; i < src_range.size(); i++) { + src_elements *= src_range[i]->extent; + } + src_elements = analyzer->Simplify(src_elements); + PrimExpr dst_elements = 1; + for (size_t i = 0; i < dst_range.size(); i++) { + dst_elements *= dst_range[i]->extent; + } + dst_elements = analyzer->Simplify(dst_elements); + ICHECK(Downcast(src_elements)->value <= + Downcast(dst_elements)->value) + << "Source buffer size larger than destination buffer size: " + << src_elements << " vs " << dst_elements; + ICHECK(size->value <= Downcast(src_elements)->value) + << "Put size larger than data size: " << size->value << " vs " + << Downcast(src_elements)->value; + + // check for size + PrimExpr broadcast_elements; + if (size->value < 0) { + broadcast_elements = src_elements; + } else { + broadcast_elements = size; + } + ICHECK((Downcast(broadcast_elements)->value) <= + Downcast(src_elements)->value) + << "Put size Larger than source buffer size: " + << (Downcast(broadcast_elements)->value) << " vs " + << Downcast(src_elements)->value; + ICHECK((Downcast(broadcast_elements)->value) <= + Downcast(dst_elements)->value) + << "Put size larger than destination buffer size: " + << 
(Downcast(broadcast_elements)->value) << " vs " + << Downcast(dst_elements)->value; + + // all checks passed, generate the call + PrimExpr src_addr = src.access_ptr(1, DataType::Handle(), 1, 0, src_elements); + PrimExpr dst_addr = dst.access_ptr(2, DataType::Handle(), 1, 0, dst_elements); + int src_core_x = src_core->value / mesh_y; + int src_core_y = src_core->value % mesh_y; + int dst_core_x = dst_core->value / mesh_y; + int dst_core_y = dst_core->value % mesh_y; + + if (src_core_x == dst_core_x) { + // 1D put via horizontal communication + Array args; + args.push_back(src_addr); + args.push_back(dst_addr); + args.push_back(Downcast(broadcast_elements)); + args.push_back(src_core); + args.push_back(0); // direction: horizontal + for (int j = 0; j < mesh_y; j++) { + if (j != dst_core_y) { + args.push_back( + IntImm(DataType::Int(32), j)); // mask: all cores except dst_core_y + } + } + Stmt put = Evaluate(Call(DataType::Handle(), broadcast_(), args)); + return put; + } else if (src_core_y == dst_core_y) { + // 1D put via vertical communication + Array args; + args.push_back(src_addr); + args.push_back(dst_addr); + args.push_back(Downcast(broadcast_elements)); + args.push_back(src_core); + args.push_back(1); // direction: vertical + for (int i = 0; i < mesh_x; i++) { + if (i != dst_core_x) { + args.push_back( + IntImm(DataType::Int(32), i)); // mask: all cores except dst_core_x + } + } + Stmt put = Evaluate(Call(DataType::Handle(), broadcast_(), args)); + return put; + } else { + Array seq; + // vertical transfer from src core to intermediate core + int intermediate_core_id = src_core_x * mesh_y + dst_core_y; + Array args1; + args1.push_back(src_addr); + args1.push_back(dst_addr); + args1.push_back(Downcast(broadcast_elements)); + args1.push_back(src_core); + args1.push_back(1); // direction: vertical + for (int i = 0; i < mesh_x; i++) { + if (i != dst_core_x) { + args1.push_back( + IntImm(DataType::Int(32), i)); // mask: all cores except dst_core_x + } + } + Stmt 
put1 = Evaluate(Call(DataType::Handle(), broadcast_(), args1)); + seq.push_back(put1); + // horizontal transfer from intermediate core to dst core + Array args2; + args2.push_back(dst.access_ptr(1, DataType::Handle(), 1, 0, src_elements)); + args2.push_back(dst.access_ptr(2, DataType::Handle(), 1, 0, dst_elements)); + args2.push_back(Downcast(broadcast_elements)); + args2.push_back(IntImm(DataType::Int(32), intermediate_core_id)); + args2.push_back(0); // direction: horizontal + for (int j = 0; j < mesh_y; j++) { + if (j != dst_core_y) { + args2.push_back( + IntImm(DataType::Int(32), j)); // mask: all cores except dst_core_y + } + } + Stmt put2 = Evaluate(Call(DataType::Handle(), broadcast_(), args2)); + seq.push_back(put2); + return SeqStmt::Flatten(seq); + } +} + +TIR_REGISTER_TL_TILE_OP(PutOp, comm_put) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +AllgatherOp::AllgatherOp(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + node->send = args[0]; + node->recv = args[1]; + node->direction = Downcast(args[2])->value; + node->size = Downcast(args[3]); + data_ = std::move(node); +} + +TileOperator AllgatherOpNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return AllgatherOp(op); +} + +// Not yet complete; it will be further refined later +LayoutMap AllgatherOpNode::ComputeLayout(const LayoutInferArgs &T, + InferLevel level, Buffer src, + Buffer dst) const { + if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" && + T.layout_map.count(src)) { + auto src_layout = T.layout_map[src].as().value(); + + PrimExpr src_rep_extent = src_layout->ReplicateExtent(); + + Array fwd; + fwd.push_back(InputPlaceholder(0)); + for (int i = 0; i < static_cast(src->shape.size()); i++) { + fwd.push_back(InputPlaceholder(i + 1)); + } + auto thd = src_layout->ForwardThread(fwd, std::nullopt); + + Fragment dst_layout = + Fragment(dst->shape, {}, thd, src_rep_extent, std::nullopt) + 
->CondenseReplicateVar() + ->BindThreadRange(T.thread_bounds); + + if (!T.layout_map.count(dst)) + return {{dst, dst_layout}}; + else { + // Check if computed layout is compatible with existing: the existing one + // must strictly contains the computed layout + auto orig_dst_layout = + T.layout_map.Get(dst).value().as().value(); + ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim()); + Array indices; + indices.reserve(dst_layout->InputDim()); + arith::Analyzer inner_analyzer; + for (int i = 0; i < dst_layout->InputDim(); ++i) { + auto x = InputPlaceholder(i); + indices.push_back(x); + // should be literal - literal = 0, any analyzer will work + ICHECK(is_zero(inner_analyzer.Simplify( + dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i]))); + inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i])); + } + + ICHECK(as_const_int(dst_layout->ReplicateExtent())); + ICHECK(as_const_int(src_layout->ReplicateExtent())); + auto dst_rep = *as_const_int(dst_layout->ReplicateExtent()); + auto src_rep = *as_const_int(src_layout->ReplicateExtent()); + if (dst_rep < src_rep || + !ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices, + inner_analyzer)) { + std::ostringstream oss; + oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. 
" + << src << "\nLHS = " << src_layout->DebugOutput() + << "\nRHS = " << orig_dst_layout->DebugOutput() + << "\nYou may need to use a shared memory to transform the " + "layout"; + throw LayoutConflictException(oss.str()); + } + + if (dst_rep > src_rep) { + return {{dst, dst_layout}}; + } + } + } + return {}; +} + +LayoutMap AllgatherOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + Buffer src_buffer = NormalizeToBufferRegion(send)->buffer; + Buffer recv_buffer = NormalizeToBufferRegion(recv)->buffer; + return ComputeLayout(T, level, src_buffer, recv_buffer); +} + +Stmt AllgatherOpNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { + Target target = T.target; + ICHECK(TargetIsSunmmio(target)) << "Allgather only supports SUNMMIO targets."; + int mesh_x = get_target_mesh(target, 0); + int mesh_y = get_target_mesh(target, 1); + + Array send_range, recv_range; + auto send_region = NormalizeToBufferRegion(send); + auto recv_region = NormalizeToBufferRegion(recv); + send_range = send_region->region; + recv_range = recv_region->region; + + int recv_num = 1; + if (direction == 0) { // horizontal + recv_num = mesh_y; + } else if (direction == 1) { // vertical + recv_num = mesh_x; + } else if (direction == 2) { // all + recv_num = mesh_x * mesh_y; + } else { + // invalid direction + ICHECK(false) << "Invalid direction value for allgather: " << direction; + } + + PrimExpr send_elements = 1; + for (size_t i = 0; i < send_range.size(); i++) { + send_elements *= send_range[i]->extent; + } + send_elements = analyzer->Simplify(send_elements); + PrimExpr recv_elements = 1; + for (size_t i = 0; i < recv_range.size(); i++) { + recv_elements *= recv_range[i]->extent; + } + recv_elements = analyzer->Simplify(recv_elements); + // check for buffer sizes + ICHECK(Downcast(send_elements)->value * recv_num <= + Downcast(recv_elements)->value) + << "Receive buffer size not enough for allgather: required " + << (Downcast(send_elements)->value * 
recv_num) << ", but got " + << Downcast(recv_elements)->value; + + // all checks passed, generate the calls + Array bcast_stmts; + + if (direction == 0) { // horizontal + for (int i = 0; i < mesh_x; i++) { + for (size_t j = 0; j < mesh_y; j++) { + Array args; + args.push_back(send); + args.push_back(recv); + args.push_back(size); + args.push_back(IntImm(DataType::Int(32), j) * send_elements); // offset + args.push_back(IntImm(DataType::Int(32), i * mesh_y + j)); // src_core + args.push_back(0); // direction: horizontal + BroadcastOp bcast = BroadcastOp(args); + Stmt bcast_stmt = bcast->Lower(T, analyzer); + bcast_stmts.push_back(bcast_stmt); + } + } + } else if (direction == 1) { // vertical + for (int j = 0; j < mesh_y; j++) { + for (size_t i = 0; i < mesh_x; i++) { + Array args; + args.push_back(send); + args.push_back(recv); + args.push_back(size); + args.push_back(IntImm(DataType::Int(32), i) * send_elements); // offset + args.push_back(IntImm(DataType::Int(32), i * mesh_y + j)); // src_core + args.push_back(1); // direction: vertical + BroadcastOp bcast = BroadcastOp(args); + Stmt bcast_stmt = bcast->Lower(T, analyzer); + bcast_stmts.push_back(bcast_stmt); + } + } + } else if (direction == 2) { // all + // first do horizontal allgather + for (int i = 0; i < mesh_x; i++) { + for (size_t j = 0; j < mesh_y; j++) { + Array args; + args.push_back(send); + args.push_back(recv); + args.push_back(size); + args.push_back(IntImm(DataType::Int(32), i * mesh_y + j) * + send_elements); // offset + args.push_back(IntImm(DataType::Int(32), i * mesh_y + j)); // src_core + args.push_back(0); // direction: horizontal + BroadcastOp bcast = BroadcastOp(args); + Stmt bcast_stmt = bcast->Lower(T, analyzer); + bcast_stmts.push_back(bcast_stmt); + } + } + // then do vertical allgather + Buffer recv_buffer = recv_region->buffer; + int allgather_size = (size->value < 0) + ? 
Downcast(send_elements)->value * mesh_y + : size->value * mesh_y; + + for (int j = 0; j < mesh_y; j++) { + for (size_t i = 0; i < mesh_x; i++) { + Array args; + args.push_back(recv_buffer.access_ptr( + 1, DataType::Handle(), 1, + IntImm(DataType::Int(32), i * mesh_y) * send_elements, + IntImm(DataType::Int(32), mesh_y) * send_elements)); + args.push_back(recv_buffer.access_ptr( + 2, DataType::Handle(), 1, + IntImm(DataType::Int(32), i * mesh_y) * send_elements, + IntImm(DataType::Int(32), mesh_y) * send_elements)); + args.push_back(IntImm(DataType::Int(32), allgather_size)); // size + args.push_back(IntImm(DataType::Int(32), i * mesh_y + j)); // src_core + args.push_back(1); // direction: vertical + Stmt bcast_stmt = + Evaluate(Call(DataType::Handle(), broadcast_(), args)); + bcast_stmts.push_back(bcast_stmt); + } + } + } + return SeqStmt::Flatten(bcast_stmts); +} + +TIR_REGISTER_TL_TILE_OP(AllgatherOp, comm_allgather) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +AllreduceOp::AllreduceOp(Array args) { + ObjectPtr node = tvm::ffi::make_object(); + node->src = args[0]; + node->dst = args[1]; + node->row_allgather = args[2]; + node->col_allgather = args[3]; + + node->type = Downcast(args[4]); + node->direction = Downcast(args[5])->value; + node->dim = Downcast(args[6]); + node->clear = Downcast(args[7]); + if (args.size() > 8) { + node->dst_copy = args[8]; + } + data_ = std::move(node); +} + +TileOperator AllreduceOpNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + return AllreduceOp(op); +} + +// Not yet complete; it will be further refined later +LayoutMap AllreduceOpNode::ComputeLayout(const LayoutInferArgs &T, + InferLevel level, Buffer src, + Buffer dst, int dim) const { + if (level >= InferLevel::kStrict) + return {}; + + if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" && + T.layout_map.count(src)) { + auto src_layout = T.layout_map[src].as().value(); + + PrimExpr 
indice_rep_extent = src->shape[dim]; + PrimExpr src_rep_extent = src_layout->ReplicateExtent(); + PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent; + + Array fwd; + fwd.push_back(InputPlaceholder(0)); + for (int i = 0; i < static_cast(src->shape.size()); i++) { + if (i == dim) { + ; + } else if (i < dim) { + fwd.push_back(InputPlaceholder(i + 1)); + } else if (i > dim) { + fwd.push_back(InputPlaceholder(i - 1 + 1)); + } + } + auto thd = src_layout->ForwardThread( + fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent)); + + // Ensure the thread count is divisible by the replicate extent. + // Otherwise, we cannot infer a valid fragment<->fragment layout. + { + arith::Analyzer analyzer; + PrimExpr num_threads = T.thread_bounds->extent; + // Though the dest_buffer_rep_extent will be compressed at + // CondenseReplicateVar, we need to check the divisibility here to avoid + // the issue that the thread count is not divisible by the replicate + // extent. + if (!analyzer.CanProve(FloorMod(num_threads, dest_buffer_rep_extent) == + 0) && + !analyzer.CanProve(FloorMod(dest_buffer_rep_extent, num_threads) == + 0)) { + ICHECK(false) << "ReduceOp fragment layout inference failed: " + "num_threads % replicate_extent != 0. " + << "This mapping requires the block's thread count to be " + "divisible by the " + << "replicate extent. 
" + << "Try one of: (1) choose a thread block size divisible " + "by replicate_extent; " + << "(2) pick a different reduce dimension or adjust the " + "source fragment layout; " + << "Details: num_threads=" << num_threads + << ", replicate_extent=" << indice_rep_extent + << ", src=" << src << ", dst=" << dst; + } + } + + Fragment dst_layout = + Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt) + ->CondenseReplicateVar() + ->BindThreadRange(T.thread_bounds); + + if (!T.layout_map.count(dst)) + return {{dst, dst_layout}}; + else { + // Check if computed layout is compatible with existing: the existing one + // must strictly contains the computed layout + auto orig_dst_layout = + T.layout_map.Get(dst).value().as().value(); + ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim()); + Array indices; + indices.reserve(dst_layout->InputDim()); + arith::Analyzer inner_analyzer; + for (int i = 0; i < dst_layout->InputDim(); ++i) { + auto x = InputPlaceholder(i); + indices.push_back(x); + // should be literal - literal = 0, any analyzer will work + ICHECK(is_zero(inner_analyzer.Simplify( + dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i]))); + inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i])); + } + + ICHECK(as_const_int(dst_layout->ReplicateExtent())); + ICHECK(as_const_int(src_layout->ReplicateExtent())); + auto dst_rep = *as_const_int(dst_layout->ReplicateExtent()); + auto src_rep = *as_const_int(src_layout->ReplicateExtent()); + if (dst_rep < src_rep || + !ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices, + inner_analyzer)) { + std::ostringstream oss; + oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. 
" + << src << "\nLHS = " << src_layout->DebugOutput() + << "\nRHS = " << orig_dst_layout->DebugOutput() + << "\nYou may need to use a shared memory to transform the " + "layout"; + throw LayoutConflictException(oss.str()); + } + + if (dst_rep > src_rep) { + return {{dst, dst_layout}}; + } + } + } + return {}; +} + +LayoutMap AllreduceOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + LayoutMap lm; + + Array dst_layout_args; + dst_layout_args.push_back(src); + dst_layout_args.push_back(dst); + dst_layout_args.push_back(type); + dst_layout_args.push_back(dim); + dst_layout_args.push_back(clear); + ReduceOp dst_layout_op = ReduceOp(dst_layout_args); + LayoutMap dst_layout_map = dst_layout_op->InferLayout(T, InferLevel::kFree); + for (const auto &kv : dst_layout_map) { + lm.Set(kv.first, kv.second); + } + + if (dst_copy.defined()) { + Array dst_copy_layout_args; + dst_copy_layout_args.push_back(src); + dst_copy_layout_args.push_back(dst_copy); + dst_copy_layout_args.push_back(type); + dst_copy_layout_args.push_back(dim); + dst_copy_layout_args.push_back(clear); + ReduceOp dst_copy_layout_op = ReduceOp(dst_copy_layout_args); + LayoutMap dst_copy_layout_map = + dst_copy_layout_op->InferLayout(T, InferLevel::kFree); + for (const auto &kv : dst_copy_layout_map) { + lm.Set(kv.first, kv.second); + } + } + + Buffer row_allgather_buffer = NormalizeToBufferRegion(row_allgather)->buffer; + LayoutMap row_allgather_layout = + ComputeLayout(T, InferLevel::kFree, NormalizeToBufferRegion(src)->buffer, + row_allgather_buffer, dim->value); + for (const auto &kv : row_allgather_layout) { + lm.Set(kv.first, kv.second); + } + + Buffer col_allgather_buffer = NormalizeToBufferRegion(col_allgather)->buffer; + LayoutMap col_allgather_layout = + ComputeLayout(T, InferLevel::kFree, NormalizeToBufferRegion(src)->buffer, + col_allgather_buffer, dim->value); + for (const auto &kv : col_allgather_layout) { + lm.Set(kv.first, kv.second); + } + + return lm; +} + +Stmt 
AllreduceOpNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { + Target target = T.target; + ICHECK(TargetIsSunmmio(target)) << "Allreduce only supports SUNMMIO targets."; + int mesh_x = get_target_mesh(target, 0); + int mesh_y = get_target_mesh(target, 1); + + ICHECK(direction == 0 || direction == 1 || direction == 2) + << "Invalid allreduce direction " << direction + << ", must be 0 (row-wise) or 1 (column-wise) or 2 (all)."; + + Array stmts; + + if (clear.as().value() == true) { + // Local reduce to dst + Array local_reduce_args; + local_reduce_args.push_back(src); + local_reduce_args.push_back(dst); + local_reduce_args.push_back(type); + local_reduce_args.push_back(dim); + local_reduce_args.push_back(IntImm(DataType::Int(32), 1)); // clear = true + ReduceOp local_reduce_op = ReduceOp(local_reduce_args); + Stmt local_reduce_stmt = local_reduce_op->Lower(T, analyzer); + stmts.push_back(local_reduce_stmt); + + if (direction == 0 or direction == 2) { // row-wise + // Allgather dst in rows to row_allgather + Array row_allgather_args; + row_allgather_args.push_back(dst); + row_allgather_args.push_back(row_allgather); + row_allgather_args.push_back( + IntImm(DataType::Int(32), 0)); // direction = horizontal + row_allgather_args.push_back(IntImm(DataType::Int(32), -1)); // size + AllgatherOp row_allgather_op = AllgatherOp(row_allgather_args); + Stmt row_allgather_stmt = row_allgather_op->Lower(T, analyzer); + stmts.push_back(row_allgather_stmt); + + // Local reduce from row_allgather to dst + Array row_reduce_args; + row_reduce_args.push_back(row_allgather); + row_reduce_args.push_back(dst); + row_reduce_args.push_back(type); + row_reduce_args.push_back(IntImm(DataType::Int(32), 0)); // dim + row_reduce_args.push_back(IntImm(DataType::Int(32), 1)); // clear = true + ReduceOp row_reduce_op = ReduceOp(row_reduce_args); + Stmt row_reduce_stmt = row_reduce_op->Lower(T, analyzer); + stmts.push_back(row_reduce_stmt); + } + + if (direction == 1 or direction 
== 2) { // column-wise + // Allgather dst in columns to col_allgather + Array col_allgather_args; + col_allgather_args.push_back(dst); + col_allgather_args.push_back(col_allgather); + col_allgather_args.push_back( + IntImm(DataType::Int(32), 1)); // direction = vertical + col_allgather_args.push_back(IntImm(DataType::Int(32), -1)); // size + AllgatherOp col_allgather_op = AllgatherOp(col_allgather_args); + Stmt col_allgather_stmt = col_allgather_op->Lower(T, analyzer); + stmts.push_back(col_allgather_stmt); + + // Local reduce from col_allgather to dst + Array col_reduce_args; + col_reduce_args.push_back(col_allgather); + col_reduce_args.push_back(dst); + col_reduce_args.push_back(type); + col_reduce_args.push_back(IntImm(DataType::Int(32), 0)); // dim + col_reduce_args.push_back(IntImm(DataType::Int(32), 1)); // clear = true + ReduceOp col_reduce_op = ReduceOp(col_reduce_args); + Stmt col_reduce_stmt = col_reduce_op->Lower(T, analyzer); + stmts.push_back(col_reduce_stmt); + } + } else { + // Local reduce to dst_copy + Array local_reduce_args; + local_reduce_args.push_back(src); + local_reduce_args.push_back(dst_copy); + local_reduce_args.push_back(type); + local_reduce_args.push_back(dim); + local_reduce_args.push_back(IntImm(DataType::Int(32), 1)); // clear = true + ReduceOp local_reduce_op = ReduceOp(local_reduce_args); + Stmt local_reduce_stmt = local_reduce_op->Lower(T, analyzer); + stmts.push_back(local_reduce_stmt); + + if (direction == 0 or direction == 2) { // row-wise + // Allgather dst in rows to row_allgather + Array row_allgather_args; + row_allgather_args.push_back(dst_copy); + row_allgather_args.push_back(row_allgather); + row_allgather_args.push_back( + IntImm(DataType::Int(32), 0)); // direction = horizontal + row_allgather_args.push_back(IntImm(DataType::Int(32), -1)); // size + AllgatherOp row_allgather_op = AllgatherOp(row_allgather_args); + Stmt row_allgather_stmt = row_allgather_op->Lower(T, analyzer); + stmts.push_back(row_allgather_stmt); + 
+ // Local reduce from row_allgather to dst + Array row_reduce_args; + row_reduce_args.push_back(row_allgather); + row_reduce_args.push_back(direction == 0 ? dst : dst_copy); + row_reduce_args.push_back(type); + row_reduce_args.push_back(IntImm(DataType::Int(32), 0)); // dim + row_reduce_args.push_back(IntImm( + DataType::Int(32), + direction == 0 ? 0 : 1)); // clear = direction == 0 ? false : true + ReduceOp row_reduce_op = ReduceOp(row_reduce_args); + Stmt row_reduce_stmt = row_reduce_op->Lower(T, analyzer); + stmts.push_back(row_reduce_stmt); + } + + if (direction == 1 or direction == 2) { // column-wise + // Allgather dst in columns to col_allgather + Array col_allgather_args; + col_allgather_args.push_back(dst_copy); + col_allgather_args.push_back(col_allgather); + col_allgather_args.push_back( + IntImm(DataType::Int(32), 1)); // direction = vertical + col_allgather_args.push_back(IntImm(DataType::Int(32), -1)); // size + AllgatherOp col_allgather_op = AllgatherOp(col_allgather_args); + Stmt col_allgather_stmt = col_allgather_op->Lower(T, analyzer); + stmts.push_back(col_allgather_stmt); + + // Local reduce from col_allgather to dst + Array col_reduce_args; + col_reduce_args.push_back(col_allgather); + col_reduce_args.push_back(dst); + col_reduce_args.push_back(type); + col_reduce_args.push_back(IntImm(DataType::Int(32), 0)); // dim + col_reduce_args.push_back(IntImm(DataType::Int(32), 0)); // clear = false + ReduceOp col_reduce_op = ReduceOp(col_reduce_args); + Stmt col_reduce_stmt = col_reduce_op->Lower(T, analyzer); + stmts.push_back(col_reduce_stmt); + } + } + + return SeqStmt::Flatten(stmts); +} + +TIR_REGISTER_TL_TILE_OP(AllreduceOp, comm_allreduce) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK() { + PutOpNode::RegisterReflection(); + BroadcastOpNode::RegisterReflection(); + AllgatherOpNode::RegisterReflection(); + AllreduceOpNode::RegisterReflection(); +} + +} // namespace tl +} // 
namespace tvm
diff --git a/src/op/comm.h b/src/op/comm.h
new file mode 100644
index 000000000..b79eaa597
--- /dev/null
+++ b/src/op/comm.h
@@ -0,0 +1,178 @@
/*!
 * \file tl/op/comm.h
 * \brief Implementation of Inter-core Communication Operators
 */

#ifndef TVM_TL_OP_COMM_H_
#define TVM_TL_OP_COMM_H_

#include "operator.h"

namespace tvm {
namespace tl {

// Accessors for the comm intrinsic ops registered in comm.cc.
TVM_DLL const Op &CoreId();
TVM_DLL const Op &comm_current_core();
TVM_DLL const Op &comm_is_current_core();
TVM_DLL const Op &comm_barrier();
TVM_DLL const Op &comm_fence();
TVM_DLL const Op &broadcast_();

using namespace tir;

/*!
 * \brief Tile operator broadcasting a buffer region from one source core to
 *  peer cores along a mesh direction.
 */
class BroadcastOpNode : public TileOperatorNode {
public:
  Buffer src, dst;            // source / destination buffers
  Array src_range, dst_range; // accessed regions of src / dst
  PrimExpr src_expr, dst_expr; // original argument expressions
  IntImm size;       // element count; -1 presumably means whole buffer (set in comm.cc)
  IntImm dst_offset; // destination element offset
  IntImm src_core;   // linear id of the broadcasting core
  int direction;     // 0: horizontal, 1: vertical, 2: all (see comm.cc comment)

  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_broadcast", BroadcastOpNode,
                                    TileOperatorNode);

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef()
        .def_ro("src", &BroadcastOpNode::src)
        .def_ro("dst", &BroadcastOpNode::dst)
        .def_ro("src_range", &BroadcastOpNode::src_range)
        .def_ro("dst_range", &BroadcastOpNode::dst_range)
        .def_ro("src_core", &BroadcastOpNode::src_core)
        .def_ro("direction", &BroadcastOpNode::direction)
        .def_ro("size", &BroadcastOpNode::size)
        .def_ro("dst_offset", &BroadcastOpNode::dst_offset);
  }

  TileOperator Clone() const;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
};

/*! \brief Reference wrapper for BroadcastOpNode. */
class BroadcastOp : public TileOperator {
public:
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BroadcastOp, TileOperator,
                                             BroadcastOpNode);
  TVM_DLL BroadcastOp(Array args);
  static const Op &Get();
};

/*!
 * \brief Tile operator copying a buffer region from one specific core to
 *  another (point-to-point put).
 */
class PutOpNode : public TileOperatorNode {
public:
  Buffer src, dst;
  Array src_range, dst_range;
  PrimExpr src_expr, dst_expr;
  IntImm src_core, dst_core; // linear ids of the two endpoints
  IntImm size;               // element count; -1 presumably whole buffer

  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_put", PutOpNode, TileOperatorNode);

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef()
        .def_ro("src", &PutOpNode::src)
        .def_ro("dst", &PutOpNode::dst)
        .def_ro("src_range", &PutOpNode::src_range)
        .def_ro("dst_range", &PutOpNode::dst_range)
        .def_ro("src_core", &PutOpNode::src_core)
        .def_ro("dst_core", &PutOpNode::dst_core)
        .def_ro("size", &PutOpNode::size);
  }

  TileOperator Clone() const;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
};

/*! \brief Reference wrapper for PutOpNode. */
class PutOp : public TileOperator {
public:
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PutOp, TileOperator, PutOpNode);
  TVM_DLL PutOp(Array args);
  static const Op &Get();
};

/*!
 * \brief Tile operator gathering every core's tile along a mesh direction
 *  into a receive buffer.
 */
class AllgatherOpNode : public TileOperatorNode {
public:
  PrimExpr send, recv; // send / receive buffer expressions
  int direction;       // 0: horizontal, 1: vertical, 2: all
  IntImm size;         // element count; -1 presumably whole buffer

  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_allgather", AllgatherOpNode,
                                    TileOperatorNode);

  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef()
        .def_ro("send", &AllgatherOpNode::send)
        .def_ro("recv", &AllgatherOpNode::recv)
        .def_ro("direction", &AllgatherOpNode::direction)
        .def_ro("size", &AllgatherOpNode::size);
  }

  TileOperator Clone() const;
  // Shared layout helper used by InferLayout for a given src/dst pair.
  LayoutMap ComputeLayout(const LayoutInferArgs &T, InferLevel level,
                          Buffer src, Buffer dst) const;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
};

/*! \brief Reference wrapper for AllgatherOpNode. */
class AllgatherOp : public TileOperator {
public:
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AllgatherOp, TileOperator,
                                             AllgatherOpNode);
  TVM_DLL AllgatherOp(Array args);
  static const Op &Get();
};

/*!
 * \brief Tile operator reducing tiles across cores; lowered (in comm.cc) to
 *  allgather phases followed by local reduces.
 */
class AllreduceOpNode : public TileOperatorNode {
public:
  PrimExpr src, dst;
  PrimExpr row_allgather, col_allgather; // staging buffers for the two phases
  PrimExpr
dst_copy; + StringImm type; + int direction; + IntImm dim; + IntImm clear; + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.comm_allreduce", AllreduceOpNode, + TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src", &AllreduceOpNode::src) + .def_ro("dst", &AllreduceOpNode::dst) + .def_ro("row_allgather", &AllreduceOpNode::row_allgather) + .def_ro("col_allgather", &AllreduceOpNode::col_allgather) + .def_ro("type", &AllreduceOpNode::type) + .def_ro("dim", &AllreduceOpNode::dim) + .def_ro("clear", &AllreduceOpNode::clear) + .def_ro("direction", &AllreduceOpNode::direction) + .def_ro("dst_copy", &AllreduceOpNode::dst_copy); + } + + TileOperator Clone() const; + LayoutMap ComputeLayout(const LayoutInferArgs &T, InferLevel level, + Buffer src, Buffer dst, int dim) const; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; +}; + +class AllreduceOp : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AllreduceOp, TileOperator, + AllreduceOpNode); + TVM_DLL AllreduceOp(Array args); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_COMM_H_ diff --git a/testing/python/language/test_tilelang_language_comm.py b/testing/python/language/test_tilelang_language_comm.py new file mode 100644 index 000000000..fc63996bf --- /dev/null +++ b/testing/python/language/test_tilelang_language_comm.py @@ -0,0 +1,339 @@ +import pytest + +import tilelang +import tilelang.language as T + +from tilelang import tvm as tvm +from tilelang.utils.target import determine_target + + +@pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ + (1024, 1024, 128, 128, "float16", "float"), +]) +def test_comm_python_api(M, N, block_M, block_N, dtype, accum_dtype): + func_str = """# from tvm.script import tir as T + +@T.prim_func +def 
main(A_handle: T.handle): + A = T.match_buffer(A_handle, (1024, 1024), "float16", strides=(1024, 1)) + # with T.block("root"): + bx = T.launch_thread("blockIdx.x", 8) + by = T.launch_thread("blockIdx.y", 8) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + with T.block("tilelang_root"): + T.reads(A[by * 128, bx * 128]) + T.writes() + A_local = T.alloc_buffer((128, 128), scope="local.fragment") + B_local = T.alloc_buffer((128, 128), scope="local.fragment") + C_local = T.alloc_buffer((16, 128, 128), scope="local.fragment") + T.copy(T.region(A[by * 128, bx * 128], 1, 128, 128), T.region(A_local[0, 0], 2, 128, 128), -1, T.bool(False), 0) + T.comm_broadcast(A_local[0:128, 0:128], B_local[0:128, 0:128], -1, 0, 6, 2) + T.comm_put(A_local[0:128, 0:128], B_local[0:128, 0:128], -1, 6, 11) + T.comm_allgather(A_local[0:128, 0:128], C_local[0:16, 0:128, 0:128], 2, -1)""" + + @T.prim_func + def main(A: T.Tensor((M, N), dtype),): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_local = T.alloc_fragment([block_M, block_N], accum_dtype) + B_local = T.alloc_fragment([block_M, block_N], accum_dtype) + C_local = T.alloc_fragment([16, block_M, block_N], accum_dtype) + T.copy(A[by * block_M, bx * block_N], A_local) + + T.comm.broadcast(A_local, B_local, (1, 2), direction="all") + T.comm.put(A_local, B_local, (1, 2), (2, 3)) + T.comm.all_gather(A_local, C_local, direction="all") + + assert main.script() == func_str, "The generated script does not match the expected output." 
+ + +@pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ + (1024, 1024, 128, 128, "float16", "float"), +]) +def test_comm_broadcast_lower(M, N, block_M, block_N, dtype, accum_dtype): + func_str = """# from tvm.script import ir as I +# from tvm.script import tir as T + +@I.ir_module +class Module: + @T.prim_func + def main(A_handle: T.handle): + T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mattr": ["device_mesh_nrow_4", "device_mesh_ncol_4"], "mcpu": "sunmmio-a4e", "tag": ""})}) + A = T.match_buffer(A_handle, (1024, 1024), "float16", strides=(1024, 1)) + # with T.block("root"): + bx = T.launch_thread("blockIdx.x", 8) + by = T.launch_thread("blockIdx.y", 8) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + with T.block("tilelang_root"): + T.reads(A[by * 128, bx * 128]) + T.writes() + A_local = T.alloc_buffer((128, 128), scope="local.fragment") + B_local = T.alloc_buffer((128, 128), scope="local.fragment") + for i in T.parallel(128): + for j in T.parallel(32): + for vec in T.vectorized(4): + A_local[i, j * 4 + vec] = T.Cast("float32", A[by * 128 + i, bx * 128 + (j * 4 + vec)]) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 6, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 2, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 6, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 10, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 
0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 14, 0)""" + + @T.prim_func + def main(A: T.Tensor((M, N), dtype),): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_local = T.alloc_fragment([block_M, block_N], accum_dtype) + B_local = T.alloc_fragment([block_M, block_N], accum_dtype) + T.copy(A[by * block_M, bx * block_N], A_local) + + T.comm.broadcast(A_local, B_local, (1, 2), direction="all") + + mod = tvm.IRModule({'main': main}) + target = determine_target("Sunmmio", return_object=True) + with tvm.target.Target(target): + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tilelang.transform.LowerTileOp()(mod) + assert mod.script() == func_str, "The generated script does not match the expected output." + + +@pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ + (1024, 1024, 128, 128, "float16", "float"), +]) +def test_comm_put_lower(M, N, block_M, block_N, dtype, accum_dtype): + func_str = """# from tvm.script import ir as I +# from tvm.script import tir as T + +@I.ir_module +class Module: + @T.prim_func + def main(A_handle: T.handle): + T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mattr": ["device_mesh_nrow_4", "device_mesh_ncol_4"], "mcpu": "sunmmio-a4e", "tag": ""})}) + A = T.match_buffer(A_handle, (1024, 1024), "float16", strides=(1024, 1)) + # with T.block("root"): + bx = T.launch_thread("blockIdx.x", 8) + by = T.launch_thread("blockIdx.y", 8) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + with T.block("tilelang_root"): + T.reads(A[by * 128, bx * 128]) + T.writes() + A_local = T.alloc_buffer((128, 128), scope="local.fragment") + B_local = T.alloc_buffer((128, 128), scope="local.fragment") + for i in T.parallel(128): + for j in T.parallel(32): + for vec in T.vectorized(4): + A_local[i, j * 4 + vec] = T.Cast("float32", A[by * 128 + i, bx * 128 + 
(j * 4 + vec)]) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 6, 1, 0, 1, 3) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_local.data, 0, 16384, 2), 16384, 7, 0, 0, 1, 2)""" + + @T.prim_func + def main(A: T.Tensor((M, N), dtype),): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_local = T.alloc_fragment([block_M, block_N], accum_dtype) + B_local = T.alloc_fragment([block_M, block_N], accum_dtype) + T.copy(A[by * block_M, bx * block_N], A_local) + + T.comm.put(A_local, B_local, (1, 2), (2, 3)) + + mod = tvm.IRModule({'main': main}) + target = determine_target("Sunmmio", return_object=True) + with tvm.target.Target(target): + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tilelang.transform.LowerTileOp()(mod) + assert mod.script() == func_str, "The generated script does not match the expected output." 
+ + +@pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ + (1024, 1024, 128, 128, "float16", "float"), +]) +def test_comm_all_gather_lower(M, N, block_M, block_N, dtype, accum_dtype): + func_str = """# from tvm.script import ir as I +# from tvm.script import tir as T + +@I.ir_module +class Module: + @T.prim_func + def main(A_handle: T.handle): + T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mattr": ["device_mesh_nrow_4", "device_mesh_ncol_4"], "mcpu": "sunmmio-a4e", "tag": ""})}) + A = T.match_buffer(A_handle, (1024, 1024), "float16", strides=(1024, 1)) + # with T.block("root"): + bx = T.launch_thread("blockIdx.x", 8) + by = T.launch_thread("blockIdx.y", 8) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + with T.block("tilelang_root"): + T.reads(A[by * 128, bx * 128]) + T.writes() + A_local = T.alloc_buffer((128, 128), scope="local.fragment") + C_local = T.alloc_buffer((16, 128, 128), scope="local.fragment") + for i in T.parallel(128): + for j in T.parallel(32): + for vec in T.vectorized(4): + A_local[i, j * 4 + vec] = T.Cast("float32", A[by * 128 + i, bx * 128 + (j * 4 + vec)]) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 16384, 2), 16384, 0, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 16384, 16384, 2), 16384, 1, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 32768, 16384, 2), 16384, 2, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 49152, 16384, 2), 16384, 3, 0) + 
T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 16384, 2), 16384, 4, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 81920, 16384, 2), 16384, 5, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 98304, 16384, 2), 16384, 6, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 114688, 16384, 2), 16384, 7, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 16384, 2), 16384, 8, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 147456, 16384, 2), 16384, 9, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 163840, 16384, 2), 16384, 10, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 180224, 16384, 2), 16384, 11, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 16384, 2), 16384, 12, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 212992, 16384, 2), 16384, 13, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 229376, 16384, 2), 
16384, 14, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), A_local.data, 0, 16384, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 245760, 16384, 2), 16384, 15, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 0, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 4, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 8, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 12, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 1, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 5, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 9, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 13, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 2, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), 
C_local.data, 65536, 65536, 2), 65536, 6, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 10, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 14, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 65536, 2), 65536, 3, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 65536, 65536, 2), 65536, 7, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 131072, 65536, 2), 65536, 11, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 1), T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 196608, 65536, 2), 65536, 15, 1)""" + + @T.prim_func + def main(A: T.Tensor((M, N), dtype),): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_local = T.alloc_fragment([block_M, block_N], accum_dtype) + C_local = T.alloc_fragment([16, block_M, block_N], accum_dtype) + T.copy(A[by * block_M, bx * block_N], A_local) + + T.comm.all_gather(A_local, C_local, direction="all") + + mod = tvm.IRModule({'main': main}) + target = determine_target("Sunmmio", return_object=True) + with tvm.target.Target(target): + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tilelang.transform.LowerTileOp()(mod) + assert mod.script() == func_str, "The generated script does not match the expected output." 
+ + +@pytest.mark.parametrize("M, N, block_M, block_N, dtype, accum_dtype", [ + (1024 * 128, 1024 * 128, 1024, 1024, "float16", "float"), +]) +def test_comm_all_reduce_lower(M, N, block_M, block_N, dtype, accum_dtype): + func_str = """# from tvm.script import ir as I +# from tvm.script import tir as T + +@I.ir_module +class Module: + @T.prim_func + def main(A_handle: T.handle): + T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mattr": ["device_mesh_nrow_4", "device_mesh_ncol_4"], "mcpu": "sunmmio-a4e", "tag": ""})}) + A = T.match_buffer(A_handle, (131072, 131072), "float16", strides=(131072, 1)) + with T.block("root"): + T.reads() + T.writes() + A_local = T.Buffer((8192,), scope="local") + T.block_attr({"layout_map": {A_local: metadata["tl.Fragment"][0]}}) + bx = T.launch_thread("blockIdx.x", 128) + by = T.launch_thread("blockIdx.y", 128) + tx = T.launch_thread("threadIdx.x", 128) + ty = T.launch_thread("threadIdx.y", 1) + tz = T.launch_thread("threadIdx.z", 1) + with T.block("tilelang_root"): + T.reads(A[by * 1024, bx * 1024]) + T.writes() + T.block_attr({"layout_map": {A_local: metadata["tl.Fragment"][0]}}) + A_local = T.alloc_buffer((8192,), data=A_local.data, scope="local") + E_local = T.alloc_buffer((1024,), scope="local") + buffer = T.alloc_buffer((32,), scope="local") + buffer_1 = T.alloc_buffer((32,), scope="local") + buffer_2 = T.alloc_buffer((1024,), scope="local") + workspace = T.alloc_buffer((128,), scope="shared.dyn") + for i in T.parallel(1024): + for j in T.parallel(256): + for vec in T.vectorized(4): + A_local[i * 8 + (j * 4 + vec) // 512 * 4 + (j * 4 + vec) % 4] = T.Cast("float32", A[by * 1024 + i, bx * 1024 + (j * 4 + vec)]) + for i in T.unroll(1024, annotations={"pragma_unroll_explicit": T.bool(False)}): + buffer_2[i] = T.float32(0.0) + for rv in T.unroll(8, annotations={"pragma_unroll_explicit": T.bool(False)}): + buffer_2[i] = buffer_2[i] + A_local[i * 8 + rv % 2 * 4 + rv // 2] + buffer_2[i] = T.call_extern("float32", 
"tl::AllReduce::run", buffer_2[i], T.tvm_access_ptr(T.type_annotation("float32"), workspace.data, 0, 128, 2)) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 0, 1024, 2), 1024, 0, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 1024, 1024, 2), 1024, 1, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 2048, 1024, 2), 1024, 2, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 3072, 1024, 2), 1024, 3, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 0, 1024, 2), 1024, 4, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 1024, 1024, 2), 1024, 5, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 2048, 1024, 2), 1024, 6, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 3072, 1024, 2), 1024, 7, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 0, 1024, 2), 1024, 8, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 1024, 1024, 2), 1024, 9, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), 
T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 2048, 1024, 2), 1024, 10, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 3072, 1024, 2), 1024, 11, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 0, 1024, 2), 1024, 12, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 1024, 1024, 2), 1024, 13, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 2048, 1024, 2), 1024, 14, 0) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer.data, 3072, 1024, 2), 1024, 15, 0) + for i in T.unroll(1024, annotations={"pragma_unroll_explicit": T.bool(False)}): + buffer_2[i] = T.float32(0.0) + for rv in T.unroll(4, annotations={"pragma_unroll_explicit": T.bool(False)}): + buffer_2[i] = buffer_2[i] + buffer[rv * 8 + i // 512 * 4 + i % 4] + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 0, 1024, 2), 1024, 0, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 1024, 1024, 2), 1024, 4, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 2048, 1024, 2), 1024, 8, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 3072, 1024, 2), 1024, 12, 1) + 
T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 0, 1024, 2), 1024, 1, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 1024, 1024, 2), 1024, 5, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 2048, 1024, 2), 1024, 9, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 3072, 1024, 2), 1024, 13, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 0, 1024, 2), 1024, 2, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 1024, 1024, 2), 1024, 6, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 2048, 1024, 2), 1024, 10, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 3072, 1024, 2), 1024, 14, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 0, 1024, 2), 1024, 3, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 1024, 1024, 2), 1024, 7, 1) + T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 2048, 1024, 2), 1024, 11, 1) + 
T.broadcast_(T.tvm_access_ptr(T.type_annotation("float32"), buffer_2.data, 0, 1024, 1), T.tvm_access_ptr(T.type_annotation("float32"), buffer_1.data, 3072, 1024, 2), 1024, 15, 1) + E_local_clear = T.allocate([1024], "float32", "local") + for i in T.unroll(1024, annotations={"pragma_unroll_explicit": T.bool(False)}): + E_local_clear_1 = T.Buffer((1024,), data=E_local_clear, scope="local") + E_local_clear_1[i] = T.float32(0.0) + for rv in T.unroll(4, annotations={"pragma_unroll_explicit": T.bool(False)}): + E_local_clear_1[i] = E_local_clear_1[i] + buffer_1[rv * 8 + i // 512 * 4 + i % 4] + E_local[i] = E_local[i] + E_local_clear_1[i] + +# Metadata omitted. Use show_meta=True in script() method to show it.""" + + @T.prim_func + def main(A: T.Tensor((M, N), dtype),): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_local = T.alloc_fragment([block_M, block_N], accum_dtype) + E_local = T.alloc_fragment([block_M], accum_dtype) + T.copy(A[by * block_M, bx * block_N], A_local) + + T.comm.all_reduce(A_local, E_local, "sum", "all", dim=-1, clear=False) + + mod = tvm.IRModule({'main': main}) + target = determine_target("Sunmmio", return_object=True) + with tvm.target.Target(target): + mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tilelang.transform.LayoutInference()(mod) + mod = tilelang.transform.LowerTileOp()(mod) + assert mod.script() == func_str, "The generated script does not match the expected output." 
if __name__ == "__main__":
    test_comm_python_api(1024, 1024, 128, 128, "float16", "float")
    test_comm_broadcast_lower(1024, 1024, 128, 128, "float16", "float")
    test_comm_put_lower(1024, 1024, 128, 128, "float16", "float")
    test_comm_all_gather_lower(1024, 1024, 128, 128, "float16", "float")
    test_comm_all_reduce_lower(1024 * 128, 1024 * 128, 1024, 1024, "float16", "float")
diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py
index f2379eda3..4cd298a2b 100644
--- a/tilelang/language/__init__.py
+++ b/tilelang/language/__init__.py
@@ -107,6 +107,7 @@ from .annotations import (  # noqa: F401
     use_swizzle, annotate_layout, annotate_safe_value, annotate_l2_hit_ratio,
 )
+from . import comm  # noqa: F401


 def import_source(source: str | None = None):
diff --git a/tilelang/language/comm.py b/tilelang/language/comm.py
new file mode 100644
index 000000000..bf4d5192a
--- /dev/null
+++ b/tilelang/language/comm.py
@@ -0,0 +1,474 @@
"""Communication intrinsics wrappers for TileLang.

This module provides small helper functions that prepare arguments and
emit TIR intrinsics for inter-core communication on a target mesh.
"""

from __future__ import annotations

from typing import Literal
from collections.abc import Iterable

from tvm import tir
import tilelang.language as T
from tilelang.utils.language import (
    to_buffer_region,)

from tilelang.carver.arch.driver import get_sunmmio_device_mesh_config

# Accepted direction aliases -> encoded int (0: horizontal, 1: vertical, 2: all).
DIRECTION_MAP = {"horizontal": 0, "h": 0, "vertical": 1, "v": 1, "all": 2, "a": 2}
# Reduction kinds understood by the comm allreduce tile op.
REDUCE_TYPE_LIST = (
    "sum",
    "abssum",
    "max",
    "min",
    "absmax",
    "bitand",
    "bitor",
    "bitxor",
)


def get_target_mesh_shape() -> dict[str, int]:
    """Get the target mesh shape as a dictionary with 'x' and 'y' keys."""
    # 'x' carries the row count and 'y' the column count of the device mesh.
    nrow, ncol = get_sunmmio_device_mesh_config()
    return {"x": nrow, "y": ncol}


def core_tuple_to_id(core_id: tuple[int, int]) -> int:
    """Convert 2D (row, col) coordinates on the mesh into a linear core id.

    Parameters
    ----------
    core_id : tuple[int, int]
        A tuple specifying the (row, col) coordinates of the core on the mesh.

    Returns
    -------
    int
        The linear core id corresponding to the provided coordinates.

    Notes
    -----
    The conversion uses the current target mesh shape obtained via
    get_target_mesh_shape().
    """
    mesh_shape = get_target_mesh_shape()
    row, col = core_id
    assert (0 <= row < mesh_shape["x"]), f"Row {row} out of bounds for mesh shape {mesh_shape}."
    assert (0 <= col < mesh_shape["y"]), f"Col {col} out of bounds for mesh shape {mesh_shape}."
    # Row-major linearization over the mesh columns.
    core_id_value = row * mesh_shape["y"] + col
    return core_id_value


def core_id_to_tuple(core_id: tir.Call) -> tuple[int, int]:
    """Convert a linear core id into 2D (row, col) coordinates on the mesh.

    Parameters
    ----------
    core_id : tir.Call
        A linear core identifier (or a TIR expression that yields one).

    Returns
    -------
    tuple[int, int]
        The (row, col) coordinates corresponding to the linear core id.

    Notes
    -----
    The conversion uses the current target mesh shape obtained via
    get_target_mesh_shape().
    """
    mesh_shape = get_target_mesh_shape()
    core_id_value = core_id
    # Inverse of core_tuple_to_id's row-major linearization.
    row = core_id_value // mesh_shape["y"]
    col = core_id_value % mesh_shape["y"]
    return (row, col)


def CoreId(core_id: int | tuple[int, int]):
    """Convert a core identifier to a linear core ID for the target mesh.

    Parameters
    ----------
    core_id : int or tuple[int, int]
        Either a linear core id (int) or a 2-tuple (row, col) specifying the
        core coordinates on the target mesh.

    Returns
    -------
    int
        The linear core id mapped into [0, mesh_x * mesh_y).

    Raises
    ------
    AssertionError, ValueError
        If the provided coordinates are out of bounds or the type is invalid.
    """
    mesh_shape = get_target_mesh_shape()
    if isinstance(core_id, tuple):
        # Tuple path: bounds are validated inside core_tuple_to_id.
        core_id_value = core_tuple_to_id(core_id)
    elif isinstance(core_id, int):
        core_id_value = core_id
        assert (0 <= core_id_value < mesh_shape["x"] * mesh_shape["y"]
                ), f"Core ID {core_id_value} out of bounds for mesh shape {mesh_shape}"
    else:
        raise ValueError("core_id must be either a tuple[int, int] or an int.")
    return tir.call_intrin("handle", tir.op.Op.get("tl.CoreId"), core_id_value)


def current_core():
    """Get the current core's identifier.

    Returns
    -------
    tir.Call
        The TIR intrinsic call handle for `tl.comm_current_core`.

    Examples
    --------
    >>> current_core()
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.comm_current_core"))


def broadcast(
    src: T.Buffer,
    dst: T.Buffer,
    src_core: tuple[int, int],
    direction: Literal["horizontal", "h", "vertical", "v", "all", "a"] = "all",
    size: int = -1,
):
    """Broadcast data from a source buffer on a specific source core to a destination buffer
    on all cores in the specified direction by emitting the TIR intrinsic tl.tileop.comm_broadcast.
    Parameters
    ----------
    src : T.Buffer
        Source buffer containing data to broadcast.
    dst : T.Buffer
        Destination buffer to receive the broadcasted data.
    src_core : tuple[int, int]
        (row, col) coordinates of the source core on the target mesh.
    direction : Literal["horizontal", "h", "vertical", "v", "all", "a"]
        Direction of broadcast: "horizontal" (or "h") for row-wise, "vertical" (or "v") for column-wise,
        and "all" (or "a") for all cores.
    size : int
        Number of elements to broadcast. If -1, the entire source buffer is used.
    Returns
    -------
    tir.Call
        The TIR intrinsic call handle for `tl.tileop.comm_broadcast`.
    Examples
    --------
    >>> broadcast(A, B, (1, 2), direction="horizontal")
    """
    assert (
        src.dtype == dst.dtype
    ), f"Source and destination buffer dtypes must match for broadcast. Got {src.dtype} vs {dst.dtype}."
    if len(src.shape) != len(dst.shape):
        raise ValueError(
            "Source and destination buffer must have the same number of dimensions for broadcast.")
    # Shapes must match dimension-wise, with 1 allowed for broadcastable dims.
    for i in range(len(src.shape)):
        assert (
            src.shape[i] == dst.shape[i] or src.shape[i] == 1 or dst.shape[i] == 1
        ), f"Source buffer shape and destination buffer shape must match for broadcast. Got {src.shape} vs {dst.shape}."

    mesh_shape = get_target_mesh_shape()
    assert (isinstance(src_core, tuple) and
            len(src_core) == 2), "src_core must be a tuple of (row, col)."
    assert (0 <= src_core[0] < mesh_shape["x"]
            ), f"src_core row {src_core[0]} out of bounds for mesh shape {mesh_shape}."
    assert (0 <= src_core[1] < mesh_shape["y"]
            ), f"src_core col {src_core[1]} out of bounds for mesh shape {mesh_shape}."

    # Total element count of src, used to sanity-check the requested size.
    src_elements = 1
    for dim in src.shape:
        src_elements *= dim
    assert isinstance(size, int) and size >= -1, "size must be an integer >= -1."
    assert (size <= src_elements), f"size {size} exceeds source buffer size {src_elements}."

    assert direction.lower() in DIRECTION_MAP, f"Invalid direction string: {direction}"

    src_region = to_buffer_region(src)
    dst_region = to_buffer_region(dst)
    src_core_id = core_tuple_to_id(src_core)
    dst_offset = 0  # Always 0 for now

    # Argument order expected by the tl.tileop.comm_broadcast tile op.
    args = (
        src_region,
        dst_region,
        size,
        dst_offset,
        src_core_id,
        DIRECTION_MAP[direction.lower()],
    )
    return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.comm_broadcast"), *args)


def put(
    src: T.Buffer,
    dst: T.Buffer,
    src_core: tuple[int, int],
    dst_core: tuple[int, int],
    size: int = -1,
):
    """Put data from a source buffer on a specific source core to a destination buffer on a specific destination core
    by emitting the TIR intrinsic tl.tileop.comm_put.
    Parameters
    ----------
    src : T.Buffer
        Source buffer containing data to put.
    dst : T.Buffer
        Destination buffer to receive the data.
    src_core : tuple[int, int]
        (row, col) coordinates of the source core on the target mesh.
+ dst_core : tuple[int, int] + (row, col) coordinates of the destination core on the target mesh. + size : int + Number of elements to put. If -1, the entire source buffer is used. + Returns + ------- + tir.Call + The TIR intrinsic call handle for `tl.tileop.comm_put`. + Examples + -------- + >>> put(A, B, (1, 2), (2, 3)) + """ + assert ( + src.dtype == dst.dtype + ), f"Source and destination buffer dtypes must match for put. Got {src.dtype} vs {dst.dtype}." + if len(src.shape) != len(dst.shape): + raise ValueError( + "Source and destination buffer must have the same number of dimensions for put.") + for i in range(len(src.shape)): + assert ( + src.shape[i] == dst.shape[i] or src.shape[i] == 1 or dst.shape[i] == 1 + ), f"Source buffer shape and destination buffer shape must be compatible for put. Got {src.shape} vs {dst.shape}." + + mesh_shape = get_target_mesh_shape() + assert (isinstance(src_core, tuple) and + len(src_core) == 2), "src_core must be a tuple of (row, col)." + assert (0 <= src_core[0] < mesh_shape["x"] + ), f"src_core row {src_core[0]} out of bounds for mesh shape {mesh_shape}." + assert (0 <= src_core[1] < mesh_shape["y"] + ), f"src_core col {src_core[1]} out of bounds for mesh shape {mesh_shape}." + assert (isinstance(dst_core, tuple) and + len(dst_core) == 2), "dst_core must be a tuple of (row, col)." + assert (0 <= dst_core[0] < mesh_shape["x"] + ), f"dst_core row {dst_core[0]} out of bounds for mesh shape {mesh_shape}." + assert (0 <= dst_core[1] < mesh_shape["y"] + ), f"dst_core col {dst_core[1]} out of bounds for mesh shape {mesh_shape}." + src_elements = 1 + for dim in src.shape: + src_elements *= dim + assert isinstance(size, int) and size >= -1, "size must be an integer >= -1." + assert (size <= src_elements), f"size {size} exceeds source buffer size {src_elements}." 
+ + src_region = to_buffer_region(src) + dst_region = to_buffer_region(dst) + src_core_id = core_tuple_to_id(src_core) + dst_core_id = core_tuple_to_id(dst_core) + args = (src_region, dst_region, size, src_core_id, dst_core_id) + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.comm_put"), *args) + + +def all_gather( + send_buffer: T.Buffer, + recv_buffer: T.Buffer, + direction: Literal["horizontal", "h", "vertical", "v", "all", "a"] = "all", + size: int = -1, +): + """Perform an all-gather operation from a send buffer to a receive buffer + by emitting the TIR intrinsic tl.tileop.comm_allgather. + Parameters + ---------- + send_buffer : T.Buffer + Buffer containing data to send. + recv_buffer : T.Buffer + Buffer to receive gathered data. + direction : Literal["horizontal", "h", "vertical", "v", "all", "a"] + Direction of all-gather: "horizontal" (or "h") for row-wise, "vertical" (or "v") for column-wise, + and "all" (or "a") for all cores. + size : int + Number of elements to send from each core. If -1, the entire send buffer is used. + Returns + ------- + tir.Call + The TIR intrinsic call handle for `tl.tileop.comm_allgather`. + Examples + -------- + >>> all_gather(A_local, C_local, direction="horizontal") + """ + assert direction.lower() in DIRECTION_MAP, f"Invalid direction string: {direction}" + + assert ( + send_buffer.dtype == recv_buffer.dtype + ), f"Source and destination buffer dtypes must match for all_gather. Got {send_buffer.dtype} vs {recv_buffer.dtype}." 
+ mesh_shape = get_target_mesh_shape() + + recv_num = 1 + if direction.lower() in ["horizontal", "h"]: + recv_num = mesh_shape["y"] + elif direction.lower() in ["vertical", "v"]: + recv_num = mesh_shape["x"] + elif direction.lower() in ["all", "a"]: + recv_num = mesh_shape["x"] * mesh_shape["y"] + + expected_recv_shape = [recv_num] + list(send_buffer.shape) + assert ( + list(recv_buffer.shape) == expected_recv_shape + ), f"Receive buffer shape must be {expected_recv_shape} to hold gathered data from {recv_num} cores, but got {recv_buffer.shape}." + + assert isinstance(size, int) and size >= -1, "size must be an integer >= -1." + send_elements = 1 + for dim in send_buffer.shape: + send_elements *= dim + assert (size <= send_elements), f"size {size} exceeds send buffer size {send_elements}." + + send_buffer_region = to_buffer_region(send_buffer) + recv_buffer_region = to_buffer_region(recv_buffer) + + args = ( + send_buffer_region, + recv_buffer_region, + DIRECTION_MAP[direction.lower()], + size, + ) + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.comm_allgather"), *args) + + +def all_reduce( + buffer: T.Buffer, + out: T.Buffer, + reduce_type: str, + direction: Literal["horizontal", "h", "vertical", "v", "all", "a"], + dim: int = -1, + clear: bool = True, +): + """Perform an all-reduce operation on a buffer and store the result in an output buffer + by emitting the TIR intrinsic tl.tileop.comm_allreduce. + Parameters + ---------- + buffer : T.Buffer + Input buffer containing data to reduce. + out : T.Buffer + Output buffer to store the reduced result. + reduce_type : str + Type of reduction operation (e.g., "sum", "max", etc.). + direction : Literal["horizontal", "h", "vertical", "v", "all", "a"] + Direction of all-reduce: "horizontal" (or "h") for row-wise, "vertical" (or "v") for column-wise, + and "all" (or "a") for all cores. + dim : int + Dimension along which to perform the reduction. Default is -1 (last dimension). 
+ clear : bool + Whether to clear the output buffer before reduction. Default is True. + Returns + ------- + tir.Call + The TIR intrinsic call handle for `tl.tileop.comm_allreduce`. + Examples + -------- + >>> all_reduce(A_local, E_local, "sum", "all", dim=-1, clear=False) + """ + assert (isinstance(dim, int) and dim >= -1 and dim < len( + buffer.shape)), f"dim {dim} out of bounds for buffer with {len(buffer.shape)} dimensions." + if dim == -1: + dim = len(buffer.shape) - 1 + + expected_shapes = [ + buffer.shape[:dim] + buffer.shape[dim + 1:], + buffer.shape[:dim] + [1] + buffer.shape[dim + 1:], + ] + if list(out.shape) not in expected_shapes: + expected_shapes_str = " or ".join(map(str, expected_shapes)) + raise ValueError( + f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, " + f"output shape is {out.shape}, expected shapes are {expected_shapes_str}") + + reduce_type = reduce_type.lower() + assert (reduce_type in REDUCE_TYPE_LIST + ), f"Reduction op must be one of {REDUCE_TYPE_LIST}, but got {reduce_type}." + + assert direction.lower() in DIRECTION_MAP, f"Invalid direction string: {direction}" + assert clear in [True, False], "clear must be a boolean value." 
+ + mesh_shape = get_target_mesh_shape() + + # Create temporary buffers for row and column allgather results + row_allgather = T.alloc_fragment(list([mesh_shape["x"]] + out.shape), out.dtype) + col_allgather = T.alloc_fragment(list([mesh_shape["y"]] + out.shape), out.dtype) + + buffer_region = to_buffer_region(buffer) + out_region = to_buffer_region(out) + row_allgather_region = to_buffer_region(row_allgather) + col_allgather_region = to_buffer_region(col_allgather) + + args = ( + buffer_region, + out_region, + row_allgather_region, + col_allgather_region, + reduce_type, + DIRECTION_MAP[direction.lower()], + dim, + clear, + ) + + # If not clearing, allocate an output copy buffer to hold intermediate results + if not clear: + out_copy = T.alloc_fragment(list(out.shape), out.dtype) + out_copy_region = to_buffer_region(out_copy) + args = ( + buffer_region, + out_region, + row_allgather_region, + col_allgather_region, + reduce_type, + DIRECTION_MAP[direction.lower()], + dim, + clear, + out_copy_region, + ) + + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.comm_allreduce"), *args) + + +def barrier(group: Iterable[tuple[int, int]] | None = None): + """Insert a synchronization barrier among a group of cores. + + Parameters + ---------- + group : iterable of tuple[int, int] | None + Optional set of core coordinates to synchronize. If omitted, the + runtime's default participant set is used. + + Returns + ------- + tir.Call + The TIR intrinsic call handle for `tl.comm_barrier`. + + Examples + -------- + >>> barrier() + >>> barrier(group=[(0,0),(0,1)]) + """ + if group is None: + return tir.call_intrin("handle", tir.op.Op.get("tl.comm_barrier")) + else: + group = [core_tuple_to_id(core_id) for core_id in group] + return tir.call_intrin("handle", tir.op.Op.get("tl.comm_barrier"), *group) + + +def fence(): + """Emit a memory/communication fence intrinsic. + + Returns + ------- + tir.Call + The TIR intrinsic call handle for `tl.comm_fence`. 
+ + Examples + -------- + >>> fence() + """ + return tir.call_intrin("handle", tir.op.Op.get("tl.comm_fence"))