diff --git a/CMakeLists.txt b/CMakeLists.txt
index a61da8afa33..13b474b1ee6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -95,6 +95,7 @@ list(APPEND NVFUSER_SRCS
   ${NVFUSER_SRCS_DIR}/device_lower/analysis/index_compute.cpp
   ${NVFUSER_SRCS_DIR}/device_lower/analysis/predicate_elimination.cpp
   ${NVFUSER_SRCS_DIR}/device_lower/analysis/sync_information.cpp
+  ${NVFUSER_SRCS_DIR}/device_lower/analysis/tensor_memory.cpp
   ${NVFUSER_SRCS_DIR}/device_lower/analysis/thread_predicate.cpp
   ${NVFUSER_SRCS_DIR}/device_lower/analysis/tma.cpp
   ${NVFUSER_SRCS_DIR}/device_lower/analysis/trivial_broadcast.cpp
diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp
index d9a5ba6bdcd..3b5ca0847b0 100644
--- a/csrc/codegen.cpp
+++ b/csrc/codegen.cpp
@@ -683,6 +683,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
       return;
     }
 
+    if (ti->view()->getMemoryType() == MemoryType::Tensor) {
+      code_ << genInline(ti->index());
+      return;
+    }
+
     if (ti->view()->getMemoryType() == MemoryType::Global &&
         kernel_->summary().sync_map->needsRawSync(ti->view()).hasBID()) {
       code_ << "*(volatile " << ti->getDataType().value() << "*)&";
@@ -3186,7 +3191,7 @@
         break;
       }
       case MemoryType::Tensor: {
-        NVF_THROW("Not implemented yet");
+        // Do nothing for now. This behavior will change soon.
        break;
       }
       default:
diff --git a/csrc/device_lower/analysis/tensor_memory.cpp b/csrc/device_lower/analysis/tensor_memory.cpp
new file mode 100644
index 00000000000..2b52fd15bbd
--- /dev/null
+++ b/csrc/device_lower/analysis/tensor_memory.cpp
@@ -0,0 +1,26 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+
+#include
+#include
+#include
+
+namespace nvfuser {
+
+TensorMemoryInfo computeTMemInfo(Fusion* fusion) {
+  bool found = false;
+  for (auto tv : fusion->allTvs()) {
+    if (tv->getMemoryType() == MemoryType::Tensor) {
+      NVF_ERROR(!found, "Only one tensor on TMem is supported");
+      found = true;
+    }
+  }
+  return {};
+}
+
+} // namespace nvfuser
diff --git a/csrc/device_lower/analysis/tensor_memory.h b/csrc/device_lower/analysis/tensor_memory.h
new file mode 100644
index 00000000000..9038e171839
--- /dev/null
+++ b/csrc/device_lower/analysis/tensor_memory.h
@@ -0,0 +1,65 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+namespace nvfuser {
+
+class Fusion;
+
+// Information used to lower tensor memory. So far, no information is needed;
+// computeTMemInfo just checks that there is only one tensor on TMem in the
+// fusion. This limitation is described in the note below, and it exists only
+// for incremental development. It will be removed soon.
+struct TensorMemoryInfo;
+TensorMemoryInfo computeTMemInfo(Fusion* fusion);
+
+// Note: [Tensor Memory Allocation]
+//
+// Tensor memory is a very special memory, so its allocation is also very
+// different from that of other memory types.
+//
+// It is highly recommended to read the PTX documentation for tensor memory
+// if you are not already familiar with it:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-memory
+//
+// The first thing to note is that TMem does not have virtualization.
+// This means we cannot just allocate starting from address 0 the way we
+// allocate shared memory and rely on a page table to translate the same
+// virtual address of different CTAs to different physical addresses. There is
+// no virtual TMem address; all addresses are physical addresses.
+//
+// Because multiple CTAs can execute on the same SM simultaneously, there must
+// be some handshaking mechanism for each CTA to know the region of TMem that
+// it can use. This is done with the PTX instruction tcgen05.alloc. To ensure
+// safety, the hardware maintains a mutex for "the right to allocate TMem".
+// At the beginning of each CTA, the CTA tries to acquire the mutex
+// automatically. If it fails, the CTA is blocked until the mutex is free.
+// This means only one CTA can allocate TMem at a time. Once a CTA has
+// finished allocating TMem, it should release the mutex to relinquish the
+// right to allocate. After the right to allocate is relinquished, this CTA
+// can no longer allocate new TMem, but it can still access the TMem that it
+// has already allocated, and it can also free that TMem. Once one CTA
+// relinquishes the right to allocate, the next blocked CTA is unblocked and
+// can acquire the mutex to allocate TMem.
+//
+// Currently, TMem allocation is not supported in nvFuser. We only allow one
+// TensorView to be on TMem, and because we never relinquish the right to
+// allocate TMem, CTAs are serialized on each SM: a new CTA can be scheduled
+// on an SM only after the previous CTA on that SM has completely finished
+// executing. Thanks to this serialization, we can skip allocation entirely
+// and assume that our single TMem TensorView owns the entire TMem, because we
+// are sure that no other CTA will be using that address. As a result, we can
+// just provide address 0 to the instructions that access TMem. In principle,
+// it is clearly wrong to write to an address that has not been allocated, but
+// because we know it will work in practice for the specific unit test we are
+// targeting, we do it anyway to enable incremental development.
+
+struct TensorMemoryInfo {};
+
+} // namespace nvfuser
diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp
index 51107107192..edff33d8e68 100644
--- a/csrc/device_lower/lower2device.cpp
+++ b/csrc/device_lower/lower2device.cpp
@@ -599,6 +599,9 @@ void GpuLower::analysis(Fusion* fusion) {
   consumerToTMAInfo() = getConsumerToTMAInfoMap(fusion_);
   dumpExprsIfEnabled(fusion_->exprs(), "getConsumerToTMAInfoMap");
+
+  tmemInfo() = computeTMemInfo(fusion_);
+  dumpExprsIfEnabled(fusion_->exprs(), "computeTMemInfo");
 }
 
 kir::Kernel* GpuLower::kernel() const {
diff --git a/csrc/device_lower/lower2device.h b/csrc/device_lower/lower2device.h
index 7e35384481b..4d101824fad 100644
--- a/csrc/device_lower/lower2device.h
+++ b/csrc/device_lower/lower2device.h
@@ -14,6 +14,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -268,6 +269,14 @@
     return consumer_to_tma_info_;
   }
 
+  const TensorMemoryInfo& tmemInfo() const {
+    return tmem_info_;
+  }
+
+  TensorMemoryInfo& tmemInfo() {
+    return tmem_info_;
+  }
+
   // Register a boolean Val as a predicate to validate at the run time. Optional
   // validation error messages can be given as args.
   template <typename... Args>
@@ -365,6 +374,9 @@
   // Keep track of the mbarrier used for each load/store operation
   std::unordered_map ldst_mbarrier_map_;
 
+  // Information about tensor memory usage
+  TensorMemoryInfo tmem_info_;
+
   // Keep track of validations needed at runtime. For example, a pair of
   //! "extent mod split_factor == 0" and an error message for divisibility check
   //! for vectorization.
diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp
index 17c34733a79..1c6d4688f01 100644
--- a/csrc/device_lower/pass/index.cpp
+++ b/csrc/device_lower/pass/index.cpp
@@ -2141,13 +2141,47 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
   }
 
   if (!ir_utils::isStMatrixOp(ldst)) {
-    in = lowerSrcIndex(
-        ldst->in(),
-        ldst->out(),
-        {},
-        ir_utils::isLdMatrixOp(ldst) || ir_utils::isCpAsyncOp(ldst));
-    out =
-        lowerDstIndex(ldst->out(), {}, ir_utils::isCpAsyncOp(ldst), as_type);
+    bool is_ldst_tmem = ldst->opType() == LoadStoreOpType::LdTMem ||
+        ldst->opType() == LoadStoreOpType::StTMem;
+    if (is_ldst_tmem) {
+      // TODO: support other types
+      NVF_ERROR(
+          dataTypeSize(ldst->in()->dtype()) == 4,
+          "For now, we only support 32-bit types in tmem");
+      NVF_ERROR(
+          dataTypeSize(ldst->out()->dtype()) == 4,
+          "For now, we only support 32-bit types in tmem");
+      // TODO: hard code size 1 for now.
+      // According to the specification of tcgen05.{ld,st}, the register
+      // operand must be viewed as a vector of 32-bit elements.
+      // See:
+      // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-memory-and-register-load-store-instructions
+      as_type = ArrayType{std::make_shared<DataType>(ldst->in()->dtype()), 1};
+    }
+    if (auto tv = dynamic_cast<TensorView*>(ldst->in());
+        tv != nullptr && tv->getMemoryType() == MemoryType::Tensor) {
+      // TODO: hard coded index zero for now.
+      auto index = IrBuilder::create<Val>(0, DataType::UInt32);
+      in = IrBuilder::create<kir::TensorIndex>(
+          tv, index, DataType::TMemAddress);
+    } else {
+      in = lowerSrcIndex(
+          ldst->in(),
+          ldst->out(),
+          {},
+          ir_utils::isLdMatrixOp(ldst) || ir_utils::isCpAsyncOp(ldst),
+          as_type);
+    }
+    if (auto tv = dynamic_cast<TensorView*>(ldst->out());
+        tv != nullptr && tv->getMemoryType() == MemoryType::Tensor) {
+      // TODO: hard coded index zero for now.
+      auto index = IrBuilder::create<Val>(0, DataType::UInt32);
+      out = IrBuilder::create<kir::TensorIndex>(
+          tv, index, DataType::TMemAddress);
+    } else {
+      out = lowerDstIndex(
+          ldst->out(), {}, ir_utils::isCpAsyncOp(ldst), as_type);
+    }
   }
 
   auto new_ldst =
       IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in, ldst->cacheOp())
diff --git a/csrc/device_lower/pass/inline_ptx.cpp b/csrc/device_lower/pass/inline_ptx.cpp
index c27fd5294f6..44ee4223167 100644
--- a/csrc/device_lower/pass/inline_ptx.cpp
+++ b/csrc/device_lower/pass/inline_ptx.cpp
@@ -119,6 +119,41 @@ class LowerToInlinePtx : public kir::ExprMutator {
             IrBuilder::create<Val>(vec_size),
             invertedPredicate(ldst->predicate())},
            kir::Asm::Options{/*volatile=*/true}));
+    } else if (ldst->opType() == LoadStoreOpType::LdTMem) {
+      // TODO: support other types of ld/st
+      auto ptx = "tcgen05.ld.sync.aligned.32x32b.x1.b32";
+      registerReplace(
+          ldst,
+          IrBuilder::create<kir::Asm>(
+              ptx,
+              std::vector<Val*>{ldst->out()},
+              std::vector<Val*>{ldst->in()}));
+      auto wait_ptx = "tcgen05.wait::ld.sync.aligned";
+      registerInsertAfter(
+          ldst,
+          IrBuilder::create<kir::Asm>(
+              wait_ptx,
+              std::vector<Val*>{},
+              std::vector<Val*>{},
+              kir::Asm::Options{/*volatile=*/true}));
+    } else if (ldst->opType() == LoadStoreOpType::StTMem) {
+      // TODO: support other types of ld/st
+      auto ptx = "tcgen05.st.sync.aligned.32x32b.x1.b32";
+      registerReplace(
+          ldst,
+          IrBuilder::create<kir::Asm>(
+              ptx,
+              std::vector<Val*>{},
+              std::vector<Val*>{ldst->out(), ldst->in()},
+              kir::Asm::Options{/*volatile=*/true}));
+      auto wait_ptx = "tcgen05.wait::st.sync.aligned";
+      registerInsertAfter(
+          ldst,
+          IrBuilder::create<kir::Asm>(
+              wait_ptx,
+              std::vector<Val*>{},
+              std::vector<Val*>{},
+              kir::Asm::Options{/*volatile=*/true}));
     }
   }
diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp
index 49ce8f820c8..ea5c5441985 100644
--- a/csrc/kernel_ir.cpp
+++ b/csrc/kernel_ir.cpp
@@ -100,7 +100,9 @@ TensorIndex::TensorIndex(
       isPointerType(index->dtype()) || index->dtype() == DataType::Index ||
           isStructType(index->dtype()) ||
           index->dtype() ==
-              DataType::UInt64 /*For matrix descriptor for hopper MMA*/,
+              DataType::UInt64 /*For matrix descriptor for hopper MMA*/
+              || index->dtype() ==
+                  DataType::UInt32 /*Temporarily enabled for TMem tensor*/,
       "Cannot index with a value other than an int/pointer/struct.");
 }
diff --git a/csrc/type.cpp b/csrc/type.cpp
index 06056f6f927..eaf18ae5e2a 100644
--- a/csrc/type.cpp
+++ b/csrc/type.cpp
@@ -243,11 +243,11 @@ static std::string data_type2string(DataType t) {
     case DataType::UInt16:
       return "uint16_t";
     case DataType::UInt32:
+    case DataType::SMemAddress:
+    case DataType::TMemAddress:
       return "uint32_t";
     case DataType::UInt64:
       return "uint64_t";
-    case DataType::SMemAddress:
-      return "unsigned";
     case DataType::ComplexFloat:
       return "std::complex<float>";
     case DataType::ComplexDouble:
@@ -860,6 +860,10 @@ const char* load_store_type2string(LoadStoreOpType t) {
       return "CpAsyncBulk";
     case LoadStoreOpType::CpAsyncBulkTensorTile:
       return "CpAsyncBulkTensorTile";
+    case LoadStoreOpType::LdTMem:
+      return "LdTMem";
+    case LoadStoreOpType::StTMem:
+      return "StTMem";
     default:
       NVF_THROW("Unexpected parallel type");
   }
diff --git a/csrc/type.h b/csrc/type.h
index 388d0bb05b2..cf577bee188 100644
--- a/csrc/type.h
+++ b/csrc/type.h
@@ -88,6 +88,7 @@ enum class PrimDataType {
   ComplexFloat,
   // Pointers
   SMemAddress,
+  TMemAddress,
   // Null
   Null
 };
@@ -196,6 +197,7 @@ struct DataType {
   static constexpr PrimDataType ComplexFloat = PrimDataType::ComplexFloat;
   static constexpr PrimDataType ComplexDouble = PrimDataType::ComplexDouble;
   static constexpr PrimDataType SMemAddress = PrimDataType::SMemAddress;
+  static constexpr PrimDataType TMemAddress = PrimDataType::TMemAddress;
   static constexpr PrimDataType Null = PrimDataType::Null;
 };
@@ -297,7 +299,7 @@ inline bool isUnsignedIntegralType(DataType dtype) {
 
 // Returns if the datatype is a pointer type
 inline bool isPointerType(DataType dtype) {
   return std::holds_alternative<PointerType>(dtype.type) ||
-      dtype == DataType::SMemAddress;
+      dtype == DataType::SMemAddress || dtype == DataType::TMemAddress;
 }
 
 // Returns if the datatype is an integer or pointer type
@@ -801,7 +803,9 @@ enum class LoadStoreOpType {
   CpAsync,
   CpAsyncBulk,
   CpAsyncBulkTensorTile,
-  StMatrix
+  StMatrix,
+  LdTMem,
+  StTMem
 };
 
 // Used to label what part of the circular buffered iterdomain
@@ -1055,11 +1059,11 @@ constexpr inline size_t primDataTypeSize(PrimDataType type) {
     case DataType::UInt16:
       return sizeof(uint16_t);
     case DataType::UInt32:
+    case DataType::SMemAddress:
+    case DataType::TMemAddress:
       return sizeof(uint32_t);
     case DataType::UInt64:
       return sizeof(uint64_t);
-    case DataType::SMemAddress:
-      return sizeof(unsigned);
     default:
       NVF_THROW("Size undefined for data type.");
   }
diff --git a/tests/cpp/test_loop_rotation.cpp b/tests/cpp/test_loop_rotation.cpp
index d8ae9d49bda..9fb4c440070 100644
--- a/tests/cpp/test_loop_rotation.cpp
+++ b/tests/cpp/test_loop_rotation.cpp
@@ -568,7 +568,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
   const unsigned smem_offset = 0;
   NVFUSER_DEFINE_MAGIC_ZERO;
   float* T4 = reinterpret_cast<float*>(array + smem_offset + 0LL);
-  unsigned i0;
+  uint32_t i0;
   i0 = toSmem(T4);
   float* ptr1;
   ptr1 = T0.data + (4LL * T0.alloc_stride[0LL]);
@@ -576,7 +576,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
   for(nvfuser_index_t i2 = 0LL; i2 < 4LL; ++i2) {
     float* ptr3;
     ptr3 = T0.data + (T0.alloc_stride[0LL] * i2);
-    unsigned i4;
+    uint32_t i4;
     i4 = i0 + (12LL * i2);
     bool b5;
     b5 = (i2 + nvfuser_zero) < T0.logical_size[0LL];
@@ -608,7 +608,7 @@ __global__ void CUDAGeneratedKernel(Tensor T0, Tensor
     ptr8 = ptr1 + (T0.alloc_stride[0LL] * i7);
     nvfuser_index_t i9;
     i9 = 4LL + i7;
-    unsigned i10;
+    uint32_t i10;
     i10 = i0 + (12LL * (i9 % 5LL));
     nvfuser_index_t i11;
     i11 = 1LL + (3LL * (i7 % 5LL));
diff --git a/tests/cpp/test_memory.cpp b/tests/cpp/test_memory.cpp
index d25a8420440..bfdb4de4807 100644
--- a/tests/cpp/test_memory.cpp
+++ b/tests/cpp/test_memory.cpp
@@ -2761,6 +2761,45 @@ TEST_F(TMADocTest, Figure15e) {
 
 // End TMA tests
 
+// Tensor memory tests
+using TMemTest = BlackwellBase;
+
+TEST_F(TMemTest, GmemRegTMemRegGmemCopy) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(1);
+  fusion.addInput(tv0);
+  auto tv1 = set(tv0); // register
+  auto tv2 = set(tv1); // tmem
+  auto tv3 = set(tv2); // register
+  auto tv4 = set(tv3); // gmem
+  fusion.addOutput(tv4);
+
+  tv2->setMemoryType(MemoryType::Tensor);
+  tv2->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::StTMem);
+  tv3->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::LdTMem);
+
+  tv4->split(0, 32);
+
+  TransformPropagator propagator(tv4);
+  MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator);
+
+  tv4->axis(0)->parallelize(ParallelType::BIDx);
+  tv4->axis(1)->parallelize(ParallelType::TIDx);
+
+  scheduler_utils::parallelizeAllLike(tv4, {tv1, tv2, tv3});
+
+  inlineMost();
+
+  KernelExecutor ke;
+  ke.compile(&fusion);
+  auto t0 = at::randn(
+      {12800}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0));
+  auto cg_outputs = ke.run({t0});
+  testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
+}
+
 using LdMatrixTestParam = std::tuple;
 
 class LdMatrixTest :
     public NVFuserFixtureParamTest<LdMatrixTestParam> {
diff --git a/tests/cpp/utils.h b/tests/cpp/utils.h
index 6bfb74ac1c5..63ffdaaffc4 100644
--- a/tests/cpp/utils.h
+++ b/tests/cpp/utils.h
@@ -620,6 +620,16 @@ class HopperBase : public NVFuserTest {
   }
 };
 
+class BlackwellBase : public NVFuserTest {
+ protected:
+  void SetUp() override {
+    if (cudaArchGuardShouldSkip(10, 0)) {
+      GTEST_SKIP() << "skipping tests on non-Blackwell GPUs";
+    }
+    NVFuserTest::SetUp();
+  }
+};
+
 // TMA is supported on Hopper and newer GPUs
 class TmaBase : public NVFuserTest {
  protected:
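
Reviewer note (illustrative addition, not part of the patch): the hand-written CUDA sketch below approximates what the TMemTest.GmemRegTMemRegGmemCopy kernel boils down to once LowerToInlinePtx has run. Only the tcgen05 instruction strings, the hard-coded TMem address 0, and the gmem -> register -> TMem -> register -> gmem structure are taken from the diff above; the kernel name, the __float_as_uint/__uint_as_float bit-casts, and the inline-asm operand constraints are assumptions of this sketch rather than nvFuser's actual codegen output, and the exact operand syntax should be checked against the PTX ISA documentation.

// Sketch only: assumes an exact launch of (N / 32) blocks x 32 threads,
// matching the test's split(0, 32) with BIDx/TIDx parallelization (N = 12800).
__global__ void tmem_copy_sketch(const float* in, float* out) {
  const unsigned i = blockIdx.x * 32 + threadIdx.x;

  float r1 = in[i]; // gmem (T0) -> register (T1)

  // The TMem TensorView (T2) is indexed with the hard-coded address 0
  // (DataType::TMemAddress, printed as uint32_t by the codegen change above).
  const uint32_t tmem_addr = 0U;

  // Register -> TMem (StTMem), then wait for the store to complete.
  uint32_t r1_bits = __float_as_uint(r1);
  asm volatile("tcgen05.st.sync.aligned.32x32b.x1.b32 [%0], {%1};"
               :
               : "r"(tmem_addr), "r"(r1_bits));
  asm volatile("tcgen05.wait::st.sync.aligned;");

  // TMem -> register (LdTMem), then wait for the load to complete.
  uint32_t r3_bits;
  asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32 {%0}, [%1];"
               : "=r"(r3_bits)
               : "r"(tmem_addr));
  asm volatile("tcgen05.wait::ld.sync.aligned;");

  out[i] = __uint_as_float(r3_bits); // register (T3) -> gmem (T4)
}

Note that no tcgen05.alloc is issued before touching address 0; that is exactly the incremental-development shortcut documented in the [Tensor Memory Allocation] note in tensor_memory.h, and it only works because the un-relinquished allocation mutex serializes CTAs on each SM.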