diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index dacb9fd20bc..1c6d4688f01 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -2152,6 +2152,10 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { dataTypeSize(ldst->out()->dtype()) == 4, "For now, we only support 32-bit types in tmem"); // TODO: hard code size 1 for now. + // According to the specification of tcgen05.{ld,st}, the register + // operand must be viewed as a vector of 32-bit elements. + // See: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-memory-and-register-load-store-instructions as_type = ArrayType{std::make_shared(ldst->in()->dtype()), 1}; } if (auto tv = dynamic_cast(ldst->in());