Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,7 @@ TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(dma_load).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(dma_store).set_num_inputs(-1).set_attr<TCallEffectKind>(
TIR_DEFINE_TL_BUILTIN(dma_copy).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(create_tma_descriptor)
Expand Down
95 changes: 9 additions & 86 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,95 +189,18 @@ TVM_DLL const Op &get_mbarrier();
TVM_DLL const Op &tma_load();

/*!
* \brief Perform a DMA load operation from source memory to destination memory.
*
* This function describes a DMA-based tensor copy with explicit shape, layout,
* memory scope. It is typically used to lower a high-level
* tensor copy into a hardware-specific DMA instruction.
*
* The source and destination tensors are described in terms of:
* - data type
* - rank and logical shape
* - layout (input shape + forward index), The T.Layout type is ObjectRef,
* which is not suitable for backend parsing, so it's two members are extracted:
* input shape and forward index, which are both Array<PrimExpr>
* - memory scope
*
* A sub-region of the source tensor can be copied by specifying the coordinate
* offset (`coord`) relative to the source base address.
*
* Example:
* For a 3D tensor A: Tensor(128, 256, 512), copying
* A[32:64, 128:192, 0:256]
* then:
* src_rank = 3
* src_shape = [128, 256, 512]
* coord = [32, 128, 0]
*
* \param data_type
* Element data type of the tensor (e.g. float32, float16).
*
* \param src_rank
* Rank (number of dimensions) of the source tensor.
*
* \param src_shape
* Logical shape of the source tensor.
* For example, Tensor(128, 256, 512) -> [128, 256, 512].
*
* \param src_input_size
* Input shape of the source layout, retrievable via Layout::getInputShape().
* For a row-major 3D tensor, this is identical to src_shape.
*
* \param src_forward
* Forward index mapping of the source layout, retrievable via
* Layout::GetForwardIndex().
* For a row-major layout of Tensor(128, 256, 512),
* this is [256 * 512, 512, 1].
*
* \param src_scope
* Memory scope of the source tensor.
* Examples: "global", "shared.asram", "shared.wsram", "shared.rsram".
*
* \param dst_rank
* Rank (number of dimensions) of the destination tensor.
*
* \param dst_shape
* Logical shape of the destination tensor.
*
* \param dst_input_size
* Input shape of the destination layout, retrievable via
* Layout::getInputShape().
*
* \param dst_forward
* Forward index mapping of the destination layout, retrievable via
* Layout::GetForwardIndex().
*
* \param dst_scope
* Memory scope of the destination tensor.
* Examples: "global", "shared.asram", "shared.wsram", "shared.rsram".
*
* \param src_addr
* Base address of the source tensor in memory .
*
* \param coord
* Coordinate offset specifying the starting point of the copy in the source
* tensor. Its length must equal src_rank.
*
* \param dst_addr
* Base address of the destination tensor in memory .
*
* \note
* Out-of-bound fill policies are currently not supported.
*/
TVM_DLL const Op &dma_load();

/*!
* \brief Perform a DMA store operation from source memory to destination
* memory. see dma_load for details.
* \brief Perform a DMA copy operation preserving full buffer region semantics.
*
* This intrinsic encodes a high-level copy between two buffer regions as
* tl.dma_copy(src_region, dst_region), where each argument is a
* tl.tileop.region Call carrying the buffer, access mask, and per-axis
* extents. It is emitted by the SUNMMIO lowering path of CopyNode and
* consumed by later target-specific codegen passes.
*
* \param src_region A tl.tileop.region PrimExpr describing the source.
* \param dst_region A tl.tileop.region PrimExpr describing the destination.
*/
TVM_DLL const Op &dma_store();
TVM_DLL const Op &dma_copy();

/*!
* \brief tvm intrinsics for loading image from global tensor to columns in
Expand Down
116 changes: 116 additions & 0 deletions src/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,30 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
}
return {};
}

// Sunmmio DMA Layout Inference
if (copy_inst == CopyInst::kSunmmioDMACopy) {
// for dma copy, we can directly apply the blockwise_zz_layout
const auto f =
ffi::Function::GetGlobal("tl.layout.make_blockwise_zz_layout");
auto result = Map<Buffer, Layout>();

if (level == InferLevel::kFree && !T.layout_map.count(src)) {
if (src.scope() != "global") {
auto layout = Downcast<Layout>((*f)(src));
result.Set(src, layout);
}
}

if (level == InferLevel::kFree && !T.layout_map.count(dst)) {
if (dst.scope() != "global") {
auto layout = Downcast<Layout>((*f)(dst));
result.Set(dst, layout);
}
}
return result;
}

// for LDSM/STSM, the layout was deduced from register layout
// so we can directly apply the layout of normal copy
// Use parallel op to infer the layout
Expand Down Expand Up @@ -804,6 +828,65 @@ bool CopyNode::CheckTMemStore(Target target) const {
dst.scope() == "shared.tmem";
}

/**
 * @brief Determine whether this CopyNode can be lowered to a DMA Copy
 * Intrinsic for Sunmmio target.
 *
 * The function returns true when all of the following hold:
 * - the target architecture advertises DMA support;
 * - the (src scope, dst scope) pair is one the DMA engine can service;
 * - the source and destination have the same element data type.
 *
 * If the source and destination dtypes differ, a warning is logged and the
 * function returns false (the caller is expected to fall back to a normal
 * copy).
 *
 * @param target The compilation target to query for dma copy support.
 * @return true if the copy can be implemented as a DMA Copy; false
 * otherwise.
 */
bool CopyNode::CheckSunmmioDMACopy(Target target) const {
  // 1. arch must support Sunmmio
  if (!TargetIsSunmmio(target))
    return false;

  // 2. src and dst scopes must form a legal DMA route.
  // Hoist the scope lookups: Buffer::scope() materializes a fresh string on
  // every call, so query each side once instead of per comparison.
  const auto src_scope = src.scope();
  const auto dst_scope = dst.scope();
  const bool scope_check =
      // DRAM -> {RSRAM, WSRAM, ASRAM}
      (src_scope == "global" &&
       (dst_scope == "shared.rsram" || dst_scope == "shared.wsram" ||
        dst_scope == "shared.asram")) ||
      // RSRAM -> {WSRAM, ASRAM, RSRAM, DRAM}
      (src_scope == "shared.rsram" &&
       (dst_scope == "shared.wsram" || dst_scope == "shared.asram" ||
        dst_scope == "shared.rsram" || dst_scope == "global"));
  if (!scope_check)
    return false;

  // 3. src and dst must have the same dtype; DMA cannot convert element
  // types in flight, so mismatches warn and fall back to a normal copy.
  if (src->dtype != dst->dtype) {
    LOG(WARNING) << "src and dst must have the same dtype for dma copy "
                 << src->name << " vs. " << dst->name << " dtype " << src->dtype
                 << " vs. " << dst->dtype << " will be fallback to normal copy";
    return false;
  }
  return true;
}

/**
* @brief Selects the most specific copy instruction supported for the given
* target and buffers.
Expand Down Expand Up @@ -848,6 +931,12 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
return CopyInst::kTMemLoad;
} else if (CheckTMemStore(target)) {
return CopyInst::kTMemStore;
} else if (TargetIsSunmmio(target)) {
auto is_copy = CheckSunmmioDMACopy(target);
if (is_copy)
return CopyInst::kSunmmioDMACopy;
ICHECK(0) << "Unsupported copy from " << src.scope() << " to "
<< dst.scope() << " of Sunmmio target.";
} else {
return CopyInst::kNormal;
}
Expand All @@ -860,6 +949,7 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
* determined copy instruction type:
* - Bulk Load/Store: Uses Tensor Memory Accelerator (TMA) instructions
* - LDSM/STSM: Uses matrix load/store instructions for tensor cores
* - DMA copy: Sunmmio specified instructions for copy
* - Normal: Uses standard load/store operations with loop transformations
* \param T LowerArgs containing target and layout map.
* \param analyzer Arithmetic analyzer for simplification.
Expand Down Expand Up @@ -894,11 +984,37 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return ldsm_copy;
} else if (copy_inst == CopyInst::kNormal) {
return LowerNormalCopy(T, analyzer);
} else if (copy_inst == CopyInst::kSunmmioDMACopy) {
auto dma_copy = LowerSunmmioDmaCopy(T, analyzer);
ICHECK(dma_copy.defined()) << "Failed to lower dma copy";
return dma_copy;
} else {
LOG(FATAL) << "Unsupported copy inst " << static_cast<int>(copy_inst);
}
}

/**
 * @brief Lower the copy operator for the SUNMMIO target.
 *
 * Wraps the source and destination regions as tl.tileop.region expressions
 * (keeping buffer identity, per-axis min/extent, and memory scope intact)
 * and emits a single `tl.dma_copy(src_region, dst_region)` intrinsic call.
 * Later SUNMMIO-specific codegen passes consume this call to produce the
 * actual DMA instructions.
 *
 * @param T Lowering context (target, layout map, etc.).
 * @param analyzer Arithmetic analyzer; unused here, kept for interface
 * consistency with the other Lower* paths.
 * @return Stmt An Evaluate wrapping the tl.dma_copy Call.
 */
Stmt CopyNode::LowerSunmmioDmaCopy(const LowerArgs &T,
                                   arith::Analyzer *analyzer) const {
  // access_mask encoding: 1 marks the read side (src), 2 the write side (dst).
  Array<PrimExpr> call_args;
  call_args.push_back(MakeRegionExpr(src, src_range, /*access_mask=*/1));
  call_args.push_back(MakeRegionExpr(dst, dst_range, /*access_mask=*/2));
  return Evaluate(Call(DataType::Handle(), dma_copy(), call_args));
}

/**
* @brief Lower the copy operator using the generic (non-specialized) path.
*
Expand Down
22 changes: 18 additions & 4 deletions src/op/copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ enum class CopyInst : uint8_t {
kBulkStore = 4, // utilize tma store
// we should separate the bulk load and store for 1d and multi-dim
// as they have different memory access patterns
kBulkLoad1D = 5, // utilize tma load 1d
kBulkStore1D = 6, // utilize tma store 1d
kTMemLoad = 7, // tcgen05.ld (tensor memory -> register)
kTMemStore = 8, // tcgen05.st (register -> tensor memory)
kBulkLoad1D = 5, // utilize tma load 1d
kBulkStore1D = 6, // utilize tma store 1d
kTMemLoad = 7, // tcgen05.ld (tensor memory -> register)
kTMemStore = 8, // tcgen05.st (register -> tensor memory)
kSunmmioDMACopy = 9, // Sunmmio DMA
};

/// Descriptor for Tensor Memory Access (TMA) copy operations
Expand Down Expand Up @@ -180,6 +181,11 @@ class CopyNode : public TileOperatorNode {
*/
bool CheckTMemStore(Target target) const;

/*!
* \brief Check if Sunmmio dma copy is supported.
*/
bool CheckSunmmioDMACopy(Target target) const;

/*!
* \brief Get the copy instruction type.
*/
Expand Down Expand Up @@ -217,6 +223,14 @@ class CopyNode : public TileOperatorNode {
*/
Stmt LowerNormalCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;

/*!
* \brief Generate lowering for SUNMMIO DMA copy.
*
* Emits a tl.dma_copy(src_region, dst_region) intrinsic that preserves full
* buffer region semantics for later SUNMMIO codegen consumption.
*/
Stmt LowerSunmmioDmaCopy(const LowerArgs &T, arith::Analyzer *analyzer) const;

/*!
* \brief Generate SIMT (thread-level) loop for copying.
*/
Expand Down
41 changes: 41 additions & 0 deletions src/op/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,47 @@ BufferRegion NormalizeToBufferRegion(const PrimExpr &arg) {
throw; // Unreachable
}

/*!
 * \brief Pack a Buffer plus per-axis ranges into a tl.tileop.region Call.
 *
 * Inverse of NormalizeToBufferRegion: a BufferRegion cannot appear directly
 * as a Call argument, so the region metadata (buffer identity, per-axis
 * min/extent, and access mode) is flattened into a PrimExpr that can ride
 * along intrinsic calls (e.g. tl.dma_copy) and be decoded again by later
 * codegen passes.
 *
 * Encoding layout:
 *   args[0]   = BufferLoad(buffer, {range[0].min, range[1].min, ...})
 *   args[1]   = access_mask (1=read, 2=write, 3=read-write)
 *   args[2+i] = range[i].extent
 *
 * \param buffer The buffer this region refers to.
 * \param ranges Per-axis [min, extent) ranges describing the tile.
 * \param access_mask 1=read, 2=write, 3=read-write.
 * \return A Call(tl.tileop.region, ...) expression.
 */
PrimExpr MakeRegionExpr(const Buffer &buffer, const Array<Range> &ranges,
                        int access_mask) {
  // Gather the per-axis minima (BufferLoad indices) and extents in one pass.
  Array<PrimExpr> mins;
  Array<PrimExpr> extents;
  for (const Range &axis : ranges) {
    mins.push_back(axis->min);
    extents.push_back(axis->extent);
  }

  // args = [BufferLoad(buffer, mins), access_mask, extent_0, extent_1, ...]
  Array<PrimExpr> args{BufferLoad(buffer, mins),
                       IntImm(DataType::Int(32), access_mask)};
  for (const PrimExpr &extent : extents) {
    args.push_back(extent);
  }
  return Call(DataType::Handle(), RegionOp::Get(), args);
}

PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region, int rw_mask,
bool require_2d) {
Buffer buf = region->buffer;
Expand Down
7 changes: 7 additions & 0 deletions src/op/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ using namespace tir;
// Note: tvm_access_ptr is no longer supported here.
TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg);

// Build a tl.tileop.region Call from a Buffer + Array<Range>.
// This is the inverse of NormalizeToBufferRegion: it packages buffer, access
// mask, and per-axis extents into a Call(RegionOp::Get(), ...) that can be
// passed as an argument to builtins like dma_copy.
TVM_DLL PrimExpr MakeRegionExpr(const Buffer &buffer,
const Array<Range> &ranges, int access_mask);

// Build a tvm_access_ptr(handle) from a BufferRegion.
// - If `require_2d` is true, checks buffer ndim >= 2.
// - For 1D regions (when allowed), offset=min, extent=extent.
Expand Down
7 changes: 5 additions & 2 deletions src/transform/lower_tile_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "../op/gemm.h"
#include "../op/gemm_sp.h"
#include "../op/operator.h"
#include "../target/utils.h"
#include "common/remap_buffer_rewriter.h"

#include "arith/ir_mutator_with_analyzer.h"
Expand Down Expand Up @@ -157,8 +158,10 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
.as<Map<Buffer, Layout>>()
.value();
for (auto [buffer, layout] : layout_map) {
buffer_remap_.Set(buffer,
makeBufferWithLayout(buffer, layout, var_remap_));
if (!TargetIsSunmmio(target_)) {
buffer_remap_.Set(buffer,
makeBufferWithLayout(buffer, layout, var_remap_));
}
layout_map_.Set(buffer, layout);
}
}
Expand Down
Loading