Naive register <-> tmem load/store support #3786

Merged 7 commits on Jan 30, 2025
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -95,6 +95,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/device_lower/analysis/index_compute.cpp
${NVFUSER_SRCS_DIR}/device_lower/analysis/predicate_elimination.cpp
${NVFUSER_SRCS_DIR}/device_lower/analysis/sync_information.cpp
${NVFUSER_SRCS_DIR}/device_lower/analysis/tensor_memory.cpp
${NVFUSER_SRCS_DIR}/device_lower/analysis/thread_predicate.cpp
${NVFUSER_SRCS_DIR}/device_lower/analysis/tma.cpp
${NVFUSER_SRCS_DIR}/device_lower/analysis/trivial_broadcast.cpp
7 changes: 6 additions & 1 deletion csrc/codegen.cpp
@@ -683,6 +683,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
return;
}

if (ti->view()->getMemoryType() == MemoryType::Tensor) {
code_ << genInline(ti->index());
return;
}

if (ti->view()->getMemoryType() == MemoryType::Global &&
kernel_->summary().sync_map->needsRawSync(ti->view()).hasBID()) {
code_ << "*(volatile " << ti->getDataType().value() << "*)&";
@@ -3186,7 +3191,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
break;
}
case MemoryType::Tensor: {
NVF_THROW("Not implemented yet");
// Do nothing for now. This behavior will change soon.
break;
}
default:
26 changes: 26 additions & 0 deletions csrc/device_lower/analysis/tensor_memory.cpp
@@ -0,0 +1,26 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on

#include <device_lower/analysis/tensor_memory.h>
#include <fusion.h>
#include <ir/all_nodes.h>

namespace nvfuser {

TensorMemoryInfo computeTMemInfo(Fusion* fusion) {
bool found = false;
for (auto tv : fusion->allTvs()) {
if (tv->getMemoryType() == MemoryType::Tensor) {
NVF_ERROR(!found, "Only one tensor on TMem is supported");
found = true;
}
}
return {};
}

} // namespace nvfuser
65 changes: 65 additions & 0 deletions csrc/device_lower/analysis/tensor_memory.h
@@ -0,0 +1,65 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

namespace nvfuser {

class Fusion;

// Information used to lower tensor memory. So far, no information is actually
// needed; computeTMemInfo just checks that there is only one tensor on TMem in
// the fusion. This limitation exists only for incremental development, is
// described in the note below, and will be removed soon.
struct TensorMemoryInfo;
TensorMemoryInfo computeTMemInfo(Fusion* fusion);

// Note: [Tensor Memory Allocation]
//
// Tensor memory (TMem) is a very special kind of memory, so its allocation is
// also very different from that of other memory types.
//
// It is highly recommended to read the PTX documentation for tensor memory
// if you are not already familiar with it:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-memory
//
// The first thing to note is that TMem is not virtualized. This means we can
// not just allocate starting from address 0, as we do for shared memory, and
// rely on a page table to translate the same virtual address in different CTAs
// to different physical addresses. There is no virtual TMem address; all
// addresses are physical addresses.
//
// Because multiple CTAs can execute on the same SM simultaneously, there must
// be some handshaking mechanism for each CTA to know which region of TMem it
// can use. This is done with the PTX instruction tcgen05.alloc. To ensure
// safety, the hardware provides a mutex guarding the right to allocate TMem.
// At the beginning of each CTA, the CTA will automatically try to acquire the
// mutex. If it fails, the CTA is blocked until the mutex is free, which means
// only one CTA can allocate TMem at a time. Once a CTA has finished allocating
// TMem, it should release the mutex to relinquish the right to allocate. After
// relinquishing that right, the CTA can no longer allocate new TMem, but it can
// still access and free the TMem it has already allocated. Once one CTA
// relinquishes the right to allocate, the next blocked CTA is unblocked and can
// acquire the mutex to allocate TMem.

[Review discussion on the paragraph above]
Collaborator: Due to this handshaking mechanism, is it better to have only a single CTA occupy an SM?
Collaborator (Author): Are you talking about kernel design for better perf? My guess is that if you allocate at the beginning of the kernel and relinquish right after allocating, the latency should be acceptable even with multiple CTAs on an SM. But we need to test it before making any conclusion.
Collaborator: Yes, for maximum performance.

// Currently, TMem allocation is not supported in nvFuser. We only allow one
// TensorView to be on TMem, and because we never relinquish the right to
// allocate TMem, CTAs are serialized on each SM: a new CTA can be scheduled on
// an SM only after the previous CTA on that SM has completely finished
// executing. Thanks to this serialization, we can skip allocation and treat our
// single TMem TensorView as owning the entire TMem, because we are sure no
// other CTA will be using those addresses. As a result, we can just pass
// address 0 to the instructions that access TMem. In principle, it is clearly
// wrong to write to an address that has not been allocated, but because we are
// sure it will work in practice for the specific unit test we are targeting, we
// accept this for now to enable incremental development.

struct TensorMemoryInfo {};

} // namespace nvfuser
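The note above describes the tcgen05.alloc handshake without showing it. The following is a minimal illustrative sketch, not code generated by this PR: the instruction spellings follow the PTX ISA page linked in the note, and the shared-memory variable, the column count of 32, and the cta_group::1 modifier are assumptions that should be checked against that documentation.

__shared__ uint32_t tmem_base; // tcgen05.alloc writes the allocated TMem base address here

__device__ void tmem_alloc_handshake_sketch() {
  uint32_t ncols = 32;
  uint32_t smem_ptr =
      static_cast<uint32_t>(__cvta_generic_to_shared(&tmem_base));
  // Acquire the hardware "right to allocate" (blocks while another CTA holds
  // it) and allocate `ncols` columns of TMem for this CTA.
  asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
               :: "r"(smem_ptr), "r"(ncols));
  // Relinquish the right to allocate so the next CTA can proceed. We can still
  // access and free the TMem we already own, but we cannot allocate more.
  asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;");

  // ... use the allocated TMem via tcgen05.ld / tcgen05.st / MMA ...
  // (a real kernel would synchronize before reading tmem_base back)

  // Free the allocation when done.
  asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
               :: "r"(tmem_base), "r"(ncols));
}

This PR deliberately skips all of the above: with a single TMem TensorView and serialized CTAs, it hard-codes address 0 instead.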
3 changes: 3 additions & 0 deletions csrc/device_lower/lower2device.cpp
@@ -599,6 +599,9 @@ void GpuLower::analysis(Fusion* fusion) {

consumerToTMAInfo() = getConsumerToTMAInfoMap(fusion_);
dumpExprsIfEnabled(fusion_->exprs(), "getConsumerToTMAInfoMap");

tmemInfo() = computeTMemInfo(fusion_);
dumpExprsIfEnabled(fusion_->exprs(), "computeTMemInfo");
}

kir::Kernel* GpuLower::kernel() const {
12 changes: 12 additions & 0 deletions csrc/device_lower/lower2device.h
@@ -14,6 +14,7 @@
#include <device_lower/analysis/fused_reduction.h>
#include <device_lower/analysis/predicate_elimination.h>
#include <device_lower/analysis/sync_information.h>
#include <device_lower/analysis/tensor_memory.h>
#include <device_lower/analysis/thread_predicate.h>
#include <device_lower/analysis/tma.h>
#include <device_lower/analysis/trivial_broadcast.h>
@@ -268,6 +269,14 @@ class GpuLower : public NonCopyable {
return consumer_to_tma_info_;
}

const TensorMemoryInfo& tmemInfo() const {
return tmem_info_;
}

TensorMemoryInfo& tmemInfo() {
return tmem_info_;
}

// Register a boolean Val as a predicate to validate at the run time. Optional
// validation error messages can be given as args.
template <typename... Args>
@@ -365,6 +374,9 @@
// Keep track of the mbarrier used for each load/store operation
std::unordered_map<const Expr*, TensorView*> ldst_mbarrier_map_;

// Information about tensor memory usage
TensorMemoryInfo tmem_info_;

// Keep track of validations needed at runtime. For example, a pair of
//! "extent mod split_factor == 0" and an error message for divisibility check
//! for vectorization.
48 changes: 41 additions & 7 deletions csrc/device_lower/pass/index.cpp
@@ -2141,13 +2141,47 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
}

if (!ir_utils::isStMatrixOp(ldst)) {
in = lowerSrcIndex(
ldst->in(),
ldst->out(),
{},
ir_utils::isLdMatrixOp(ldst) || ir_utils::isCpAsyncOp(ldst));
out =
lowerDstIndex(ldst->out(), {}, ir_utils::isCpAsyncOp(ldst), as_type);
bool is_ldst_tmem = ldst->opType() == LoadStoreOpType::LdTMem ||
ldst->opType() == LoadStoreOpType::StTMem;
if (is_ldst_tmem) {
// TODO: support other types
NVF_ERROR(
dataTypeSize(ldst->in()->dtype()) == 4,
"For now, we only support 32-bit types in tmem");
NVF_ERROR(
dataTypeSize(ldst->out()->dtype()) == 4,
"For now, we only support 32-bit types in tmem");
// TODO: hard code size 1 for now.
// According to the specification of tcgen05.{ld,st}, the register
// operand must be viewed as a vector of 32-bit elements.
// See:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-memory-and-register-load-store-instructions
as_type = ArrayType{std::make_shared<DataType>(ldst->in()->dtype()), 1};
}
if (auto tv = dynamic_cast<TensorView*>(ldst->in());
tv != nullptr && tv->getMemoryType() == MemoryType::Tensor) {
// TODO: hard coded index zero for now.
auto index = IrBuilder::create<Val>(0, DataType::UInt32);
in = IrBuilder::create<kir::TensorIndex>(
tv, index, DataType::TMemAddress);
} else {
in = lowerSrcIndex(
ldst->in(),
ldst->out(),
{},
ir_utils::isLdMatrixOp(ldst) || ir_utils::isCpAsyncOp(ldst),
as_type);
}
if (auto tv = dynamic_cast<TensorView*>(ldst->out());
tv != nullptr && tv->getMemoryType() == MemoryType::Tensor) {
// TODO: hard coded index zero for now.
auto index = IrBuilder::create<Val>(0, DataType::UInt32);
out = IrBuilder::create<kir::TensorIndex>(
tv, index, DataType::TMemAddress);
} else {
out = lowerDstIndex(
ldst->out(), {}, ir_utils::isCpAsyncOp(ldst), as_type);
}
}
auto new_ldst =
IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in, ldst->cacheOp())
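To make the operand shapes concrete: under the restrictions hard-coded above (a single TMem TensorView, 32-bit element type, vector width 1, address 0), the two sides built here correspond roughly to the following in the generated kernel. The names are hypothetical; the real kernel uses nvFuser's generated T#/i# identifiers and its runtime Array type.

Array<float, 1, 1> regs;   // register operand: `as_type` is an ArrayType of one 32-bit element
uint32_t tmem_addr = 0U;   // TMem operand: kir::TensorIndex with dtype TMemAddress and index 0
// These become the {register} and [address] operands of tcgen05.ld/st in inline_ptx.cpp below.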
35 changes: 35 additions & 0 deletions csrc/device_lower/pass/inline_ptx.cpp
@@ -119,6 +119,41 @@ class LowerToInlinePtx : public kir::ExprMutator {
IrBuilder::create<Val>(vec_size),
invertedPredicate(ldst->predicate())},
kir::Asm::Options{/*volatile=*/true}));
} else if (ldst->opType() == LoadStoreOpType::LdTMem) {
// TODO: support other types of ld/st
auto ptx = "tcgen05.ld.sync.aligned.32x32b.x1.b32";
registerReplace(
ldst,
IrBuilder::create<kir::Asm>(
ptx,
std::vector<Val*>{ldst->out()},
std::vector<Val*>{ldst->in()}));
auto wait_ptx = "tcgen05.wait::ld.sync.aligned";
registerInsertAfter(
ldst,
IrBuilder::create<kir::Asm>(
wait_ptx,
std::vector<Val*>{},
std::vector<Val*>{},
kir::Asm::Options{/*volatile=*/true}));
} else if (ldst->opType() == LoadStoreOpType::StTMem) {
// TODO: support other types of ld/st
auto ptx = "tcgen05.st.sync.aligned.32x32b.x1.b32";
registerReplace(
ldst,
IrBuilder::create<kir::Asm>(
ptx,
std::vector<Val*>{},
std::vector<Val*>{ldst->out(), ldst->in()},
kir::Asm::Options{/*volatile=*/true}));
auto wait_ptx = "tcgen05.wait::st.sync.aligned";
registerInsertAfter(
ldst,
IrBuilder::create<kir::Asm>(
wait_ptx,
std::vector<Val*>{},
std::vector<Val*>{},
kir::Asm::Options{/*volatile=*/true}));
}
}

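Putting the pieces together, the kir::Asm nodes registered above correspond roughly to the following inline assembly in the generated kernel. This is a sketch rather than the literal nvFuser output (operand formatting is handled by the codegen), but the 32x32b.x1.b32 variant and the hard-coded address 0 reflect the current restrictions.

uint32_t taddr = 0U;  // TMem address operand (hard-coded 0 for now; see tensor_memory.h)
uint32_t r;           // register operand, viewed as a vector of one 32-bit element

// LdTMem: TMem -> register, then wait for the load to complete.
asm("tcgen05.ld.sync.aligned.32x32b.x1.b32 {%0}, [%1];" : "=r"(r) : "r"(taddr));
asm volatile("tcgen05.wait::ld.sync.aligned;");

// StTMem: register -> TMem, then wait for the store to complete.
asm volatile("tcgen05.st.sync.aligned.32x32b.x1.b32 [%0], {%1};" :: "r"(taddr), "r"(r));
asm volatile("tcgen05.wait::st.sync.aligned;");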
4 changes: 3 additions & 1 deletion csrc/kernel_ir.cpp
@@ -100,7 +100,9 @@ TensorIndex::TensorIndex(
isPointerType(index->dtype()) || index->dtype() == DataType::Index ||
isStructType(index->dtype()) ||
index->dtype() ==
DataType::UInt64 /*For matrix descriptor for hopper MMA*/,
DataType::UInt64 /*For matrix descriptor for hopper MMA*/
|| index->dtype() ==
DataType::UInt32 /*Temporarily enabled for TMem tensor*/,
"Cannot index with a value other than an int/pointer/struct.");
}

8 changes: 6 additions & 2 deletions csrc/type.cpp
@@ -243,11 +243,11 @@ static std::string data_type2string(DataType t) {
case DataType::UInt16:
return "uint16_t";
case DataType::UInt32:
case DataType::SMemAddress:
case DataType::TMemAddress:
return "uint32_t";
case DataType::UInt64:
return "uint64_t";
case DataType::SMemAddress:
return "unsigned";
case DataType::ComplexFloat:
return "std::complex<float>";
case DataType::ComplexDouble:
@@ -860,6 +860,10 @@ const char* load_store_type2string(LoadStoreOpType t) {
return "CpAsyncBulk";
case LoadStoreOpType::CpAsyncBulkTensorTile:
return "CpAsyncBulkTensorTile";
case LoadStoreOpType::LdTMem:
return "LdTMem";
case LoadStoreOpType::StTMem:
return "StTMem";
default:
NVF_THROW("Unexpected parallel type");
}
12 changes: 8 additions & 4 deletions csrc/type.h
@@ -88,6 +88,7 @@ enum class PrimDataType {
ComplexFloat,
// Pointers
SMemAddress,
TMemAddress,
// Null
Null
};
@@ -196,6 +197,7 @@ struct DataType {
static constexpr PrimDataType ComplexFloat = PrimDataType::ComplexFloat;
static constexpr PrimDataType ComplexDouble = PrimDataType::ComplexDouble;
static constexpr PrimDataType SMemAddress = PrimDataType::SMemAddress;
static constexpr PrimDataType TMemAddress = PrimDataType::TMemAddress;
static constexpr PrimDataType Null = PrimDataType::Null;
};

@@ -297,7 +299,7 @@ inline bool isUnsignedIntegralType(DataType dtype) {
// Returns if the datatype is a pointer type
inline bool isPointerType(DataType dtype) {
return std::holds_alternative<PointerType>(dtype.type) ||
dtype == DataType::SMemAddress;
dtype == DataType::SMemAddress || dtype == DataType::TMemAddress;
}

// Returns if the datatype is an integer or pointer type
@@ -801,7 +803,9 @@ enum class LoadStoreOpType {
CpAsync,
CpAsyncBulk,
CpAsyncBulkTensorTile,
StMatrix
StMatrix,
LdTMem,
StTMem
};

// Used to label what part of the circular buffered iterdomain
Expand Down Expand Up @@ -1055,11 +1059,11 @@ constexpr inline size_t primDataTypeSize(PrimDataType type) {
case DataType::UInt16:
return sizeof(uint16_t);
case DataType::UInt32:
case DataType::SMemAddress:
case DataType::TMemAddress:
return sizeof(uint32_t);
case DataType::UInt64:
return sizeof(uint64_t);
case DataType::SMemAddress:
return sizeof(unsigned);
default:
NVF_THROW("Size undefined for data type.");
}
6 changes: 3 additions & 3 deletions tests/cpp/test_loop_rotation.cpp
@@ -568,15 +568,15 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2>
const unsigned smem_offset = 0;
NVFUSER_DEFINE_MAGIC_ZERO;
float* T4 = reinterpret_cast<float*>(array + smem_offset + 0LL);
unsigned i0;
uint32_t i0;
i0 = toSmem(T4);
float* ptr1;
ptr1 = T0.data + (4LL * T0.alloc_stride[0LL]);
#pragma unroll 4
for(nvfuser_index_t i2 = 0LL; i2 < 4LL; ++i2) {
float* ptr3;
ptr3 = T0.data + (T0.alloc_stride[0LL] * i2);
unsigned i4;
uint32_t i4;
i4 = i0 + (12LL * i2);
bool b5;
b5 = (i2 + nvfuser_zero) < T0.logical_size[0LL];
@@ -608,7 +608,7 @@ __global__ void CUDAGeneratedKernel(Tensor<float, 2, 2> T0, Tensor<float, 2, 2>
ptr8 = ptr1 + (T0.alloc_stride[0LL] * i7);
nvfuser_index_t i9;
i9 = 4LL + i7;
unsigned i10;
uint32_t i10;
i10 = i0 + (12LL * (i9 % 5LL));
nvfuser_index_t i11;
i11 = 1LL + (3LL * (i7 % 5LL));
Expand Down
Loading