diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 98bd5a1624f..a0671795e3a 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -3183,7 +3183,12 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { if (va.find(tv) != va.end()) { aligned_array_of_regs_.insert(tv); } - } break; + break; + } + case MemoryType::Tensor: { + NVF_THROW("Not implemented yet"); + break; + } default: NVF_THROW("Unexpected memory type"); } diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index 8af6be1b7f1..6896288fc5b 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -713,6 +713,7 @@ inline bool isMemoryPartitionedAcross( return isParallelTypeThread(parallel_type) || isParallelTypeDeviceDim(parallel_type); case MemoryType::Shared: + case MemoryType::Tensor: return isParallelTypeBlockDim(parallel_type) || isParallelTypeDeviceDim(parallel_type); case MemoryType::Global: @@ -732,7 +733,8 @@ inline bool isMemorySharedAcross( // Nothing is shared if it's Local return false; case MemoryType::Shared: - // Only TID parallelized domains are shared if it's Shared + case MemoryType::Tensor: + // Only TID parallelized domains are shared if it's Shared or Tensor return isParallelTypeThreadDim(parallel_type); case MemoryType::Global: // Only TID and BID parallelized domains are shared if it's Global diff --git a/csrc/kernel.cpp b/csrc/kernel.cpp index caf23e9df6a..9f255388856 100644 --- a/csrc/kernel.cpp +++ b/csrc/kernel.cpp @@ -91,6 +91,8 @@ class KernelIrScanner : private IrVisitor { summary_.dynamic_lmem_allocations.emplace_back(allocate); } break; + case MemoryType::Tensor: + break; default: NVF_THROW("Unknown memory type to allocate."); } diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index f04eee6a7d9..69d05bdf230 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -60,6 +60,9 @@ std::string TensorView::toString(int indent_size) const { case MemoryType::Local: ss << "_l"; break; + case MemoryType::Tensor: + ss << "_t"; + break; default: NVF_THROW("Unknown tensor memory type."); } diff --git a/csrc/type.cpp b/csrc/type.cpp index c0870bc3332..06056f6f927 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -774,6 +774,8 @@ static const char* memory_type2string(MemoryType t) { return "shared"; case MemoryType::Global: return "global"; + case MemoryType::Tensor: + return "tensor"; default: NVF_THROW("Unexpected MemoryType"); } diff --git a/csrc/type.h b/csrc/type.h index 3eb2a286f24..388d0bb05b2 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -742,7 +742,7 @@ static constexpr std::array kParallelTypeTIDs = { static constexpr std::array kParallelTypeDIDs = { ParallelType::DIDx}; -enum class MemoryType { Local, Shared, Global }; +enum class MemoryType { Local, Shared, Global, Tensor }; // Symbolic: Undetermined between Iteration or Broadcast enum class IterType {