Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,7 @@ TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(dma_load).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(dma_store).set_num_inputs(-1).set_attr<TCallEffectKind>(
TIR_DEFINE_TL_BUILTIN(dma_copy).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(create_tma_descriptor)
Expand Down
50 changes: 23 additions & 27 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ TVM_DLL const Op &get_mbarrier();
TVM_DLL const Op &tma_load();

/*!
* \brief Perform a DMA load operation from source memory to destination memory.
* \brief Perform a DMA copy operation from source memory to destination memory.
*
* This function describes a DMA-based tensor copy with explicit shape, layout,
* memory scope. It is typically used to lower a high-level
Expand All @@ -211,38 +211,42 @@ TVM_DLL const Op &tma_load();
* A[32:64, 128:192, 0:256]
* then:
* src_rank = 3
* src_shape = [128, 256, 512]
* src_region_shape = [32, 64, 256]
* coord = [32, 128, 0]
*
* \param src_scope
* Memory scope of the source tensor.
* Examples: "global", "shared.asram", "shared.wsram", "shared.rsram".
*
* \param dst_scope
* Memory scope of the destination tensor.
* Examples: "global", "shared.asram", "shared.wsram", "shared.rsram".
*
* \param data_type
* Element data type of the tensor (e.g. float32, float16).
*
* \param src_rank
* Rank (number of dimensions) of the source tensor.
*
* \param src_shape
* Logical shape of the source tensor.
* For example, Tensor(128, 256, 512) -> [128, 256, 512].
* \param src_region_shape
* Logical shape of the source buffer region.
* For example, A[32:64, 128:192, 0:256] -> [32, 64, 256].
*
* \param src_input_size
* Input shape of the source layout, retrievable via Layout::getInputShape().
* Input shape of the source layout, retrievable via LayoutNode::InputShape().
* For a row-major 3D tensor, this is identical to src_shape.
*
* \param src_forward
* Forward index mapping of the source layout, retrievable via
* Layout::GetForwardIndex().
* LayoutNode::GetForwardIndex().
* For a row-major layout of Tensor(128, 256, 512),
* this is [256 * 512, 512, 1].
*
* \param src_scope
* Memory scope of the source tensor.
* Examples: "global", "shared.asram", "shared.wsram", "shared.rsram".
*
* \param dst_rank
* Rank (number of dimensions) of the destination tensor.
*
* \param dst_shape
* Logical shape of the destination tensor.
* \param dst_region_shape
* Logical shape of the destination buffer region.
*
* \param dst_input_size
* Input shape of the destination layout, retrievable via
Expand All @@ -252,32 +256,24 @@ TVM_DLL const Op &tma_load();
* Forward index mapping of the destination layout, retrievable via
* Layout::GetForwardIndex().
*
* \param dst_scope
* Memory scope of the destination tensor.
* Examples: "global", "shared.asram", "shared.wsram", "shared.rsram".
*
* \param src_addr
 * Base address of the source tensor in memory.
*
* \param coord
* \param src_coord
* Coordinate offset specifying the starting point of the copy in the source
* tensor. Its length must equal src_rank.
*
* \param dst_addr
 * Base address of the destination tensor in memory.
*
* \param dst_coord
* Coordinate offset specifying the starting point of the copy in the
* destination tensor. Its length must equal dst_rank.
*
* \note
* Out-of-bound fill policies are currently not supported.
*/
TVM_DLL const Op &dma_load();

/*!
* \brief Perform a DMA store operation from source memory to destination
* memory. see dma_load for details.
*
*
*/
TVM_DLL const Op &dma_store();
TVM_DLL const Op &dma_copy();

/*!
* \brief tvm intrinsics for loading image from global tensor to columns in
Expand Down
222 changes: 222 additions & 0 deletions src/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,28 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
}
return {};
}

if (copy_inst == CopyInst::kDMACopy) {
// for dma copy, we can directly apply the blockwise_zz_layout
const auto f =
ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout");
auto result = Map<Buffer, Layout>();

if (level == InferLevel::kFree && !T.layout_map.count(src)) {
if (src.scope() != "global") {
auto layout = Downcast<Layout>((*f)(src));
result.Set(src, layout);
}
}

if (level == InferLevel::kFree && !T.layout_map.count(dst)) {
if (dst.scope() != "global") {
auto layout = Downcast<Layout>((*f)(dst));
result.Set(dst, layout);
}
}
return result;
}
// for LDSM/STSM, the layout was deduced from register layout
// so we can directly apply the layout of normal copy
// Use parallel op to infer the layout
Expand All @@ -573,6 +595,67 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
}
return par_op_->InferLayout(T, level);
}

/**
 * @brief Determine whether this CopyNode can be lowered to a DMA copy
 * instruction.
 *
 * The function returns true when all of the following hold:
 * - the target architecture advertises Sunmmio DMA support;
 * - the (src scope, dst scope) pair is one of the routes the DMA engine
 *   can service (see the list in the body);
 * - the source and destination have the same element data type.
 *
 * If the source and destination dtypes differ, a warning is logged and the
 * function returns false (the caller is expected to fall back to a normal
 * copy).
 *
 * @param target The compilation target to query for DMA copy support.
 * @param analyzer Arithmetic analyzer; currently unused, kept for interface
 *                 parity with the other Check* helpers.
 * @param check_last_dim Currently unused, kept for interface parity.
 * @return true if the copy can be implemented as a DMA copy; false
 *         otherwise.
 */
bool CopyNode::CheckDMACopy(Target target, arith::Analyzer *analyzer,
                            bool check_last_dim) const {
  // 1. arch must support Sunmmio
  if (!TargetIsSunmmio(target))
    return false;

  // 2. the (src, dst) scope pair must be a route the DMA engine supports:
  //    DRAM  -> {RSRAM, WSRAM, ASRAM}
  //    RSRAM -> {WSRAM, ASRAM, RSRAM, DRAM}
  const auto src_scope = src.scope();
  const auto dst_scope = dst.scope();
  const bool from_dram =
      src_scope == "global" &&
      (dst_scope == "shared.rsram" || dst_scope == "shared.wsram" ||
       dst_scope == "shared.asram");
  const bool from_rsram =
      src_scope == "shared.rsram" &&
      (dst_scope == "shared.wsram" || dst_scope == "shared.asram" ||
       dst_scope == "shared.rsram" || dst_scope == "global");
  if (!from_dram && !from_rsram)
    return false;

  // 3. src and dst must have the same dtype; on mismatch warn and let the
  //    caller fall back to a normal copy instead of failing compilation.
  if (src->dtype != dst->dtype) {
    LOG(WARNING) << "src and dst must have the same dtype for dma copy "
                 << src->name << " vs. " << dst->name << " dtype " << src->dtype
                 << " vs. " << dst->dtype << " will be fallback to normal copy";
    return false;
  }
  return true;
}

/**
* @brief Determine whether this CopyNode can be lowered to a Bulk Load (TMA)
* instruction.
Expand Down Expand Up @@ -848,6 +931,12 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
return CopyInst::kTMemLoad;
} else if (CheckTMemStore(target)) {
return CopyInst::kTMemStore;
} else if (TargetIsSunmmio(target)) {
auto is_copy = CheckDMACopy(target, analyzer);
if (is_copy)
return CopyInst::kDMACopy;
ICHECK(0) << "Unsupported copy from " << src.scope() << " to "
<< dst.scope() << " of Sunmmio target.";
} else {
return CopyInst::kNormal;
}
Expand All @@ -860,6 +949,7 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
* determined copy instruction type:
* - Bulk Load/Store: Uses Tensor Memory Accelerator (TMA) instructions
* - LDSM/STSM: Uses matrix load/store instructions for tensor cores
* - DMA Load/Store: Sunmmio specified instructions for copy
* - Normal: Uses standard load/store operations with loop transformations
* \param T LowerArgs containing target and layout map.
* \param analyzer Arithmetic analyzer for simplification.
Expand All @@ -874,6 +964,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
pass_ctx->GetConfig<Bool>(kDisableTMALower, Bool(false)).value();
auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
T.layout_map, analyzer);

if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) {
auto tmem_copy = LowerTmemCopy(T, analyzer);
ICHECK(tmem_copy.defined()) << "Failed to lower tensor memory copy";
Expand All @@ -894,11 +985,141 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return ldsm_copy;
} else if (copy_inst == CopyInst::kNormal) {
return LowerNormalCopy(T, analyzer);
} else if (copy_inst == CopyInst::kDMACopy) {
auto dma_copy = LowerDMACopy(T, analyzer, copy_inst);
ICHECK(dma_copy.defined()) << "Failed to lower dma load/store";
return dma_copy;
} else {
LOG(FATAL) << "Unsupported copy inst " << static_cast<int>(copy_inst);
}
}

/**
 * @brief Lower a Copy operator to a DMA transfer.
 *
 * Serializes the scopes, element dtype, region shapes, layouts, base
 * addresses and start coordinates of the source and destination into the
 * argument list of the tl::dma_copy builtin, in the order documented in
 * builtin.h.
 *
 * @param T LowerArgs containing target information, thread/bounds variables,
 *          and layout/buffer remap information.
 * @param analyzer Analyzer used to prove shapes/contiguity/equality
 *          constraints.
 * @param copy_inst Must be CopyInst::kDMACopy.
 * @return Stmt A TIR statement performing the DMA copy.
 */
Stmt CopyNode::LowerDMACopy(const LowerArgs &T, arith::Analyzer *analyzer,
                            CopyInst copy_inst) const {
  ICHECK(copy_inst == CopyInst::kDMACopy)
      << "Invalid copy inst " << static_cast<int>(copy_inst);

  Array<PrimExpr> args;

  // Appends the layout input shape followed by the layout forward index
  // mapping of `buf`. Global buffers are looked up in the global layout map,
  // on-chip buffers in the regular layout map.
  auto push_layout_args = [&](const auto &buf) {
    if (buf.scope() == "global") {
      ICHECK(T.global_layout_map.count(buf))
          << "Layout of buffer " << buf << " not found.";
      auto layout = T.global_layout_map.at(buf);
      for (auto s : layout->InputShape()) {
        args.push_back(s);
      }
      for (auto s : layout->GetForwardIndex()) {
        args.push_back(s);
      }
    } else {
      ICHECK(T.layout_map.count(buf))
          << "Layout of buffer " << buf << " not found.";
      auto layout = T.layout_map.at(buf);
      for (auto s : layout->InputShape()) {
        args.push_back(s);
      }
      for (auto s : layout->GetForwardIndex()) {
        args.push_back(s);
      }
    }
  };

  // Appends the base address of `buf`: global buffers pass the raw data var,
  // on-chip buffers pass an access_ptr covering the whole allocation
  // (rw_mask 1 = read, 2 = write).
  auto push_addr = [&](const auto &buf, int rw_mask) {
    if (buf.scope() == "global") {
      args.push_back(buf->data);
    } else {
      PrimExpr total_elements = 1;
      for (auto e : buf->shape) {
        total_elements *= e;
      }
      args.push_back(
          buf.access_ptr(rw_mask, DataType::Handle(), 1, 0, total_elements));
    }
  };

  // \param src_scope and \param dst_scope
  args.push_back(StringImm(src.scope()));
  args.push_back(StringImm(dst.scope()));
  // \param data_type (src and dst dtypes were verified equal in CheckDMACopy)
  args.push_back(to_CUtensorMapDataType(src->dtype));

  // \param src_rank
  args.push_back(static_cast<int>(src->shape.size()));
  // \param src_region_shape
  for (auto r : src_range) {
    args.push_back(r->extent);
  }
  // \param src_input_size & \param src_forward
  push_layout_args(src);

  // \param dst_rank
  args.push_back(static_cast<int>(dst->shape.size()));
  // \param dst_region_shape
  for (auto r : dst_range) {
    args.push_back(r->extent);
  }
  // \param dst_input_size & \param dst_forward
  push_layout_args(dst);

  // \param src_addr
  push_addr(src, 1);
  // \param src_coord
  for (auto r : src_range) {
    args.push_back(r->min);
  }
  // \param dst_addr
  push_addr(dst, 2);
  // \param dst_coord
  for (auto r : dst_range) {
    args.push_back(r->min);
  }

  return Evaluate(Call(DataType::Handle(), dma_copy(), args));
}

/**
* @brief Lower the copy operator using the generic (non-specialized) path.
*
Expand Down Expand Up @@ -1763,6 +1984,7 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
return tma_copy;
}

/*!
* \brief Encode the TMA descriptor into an array of PrimExpr.
* This function serializes the TMA descriptor fields into a format suitable for
Expand Down
Loading