diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index 97b130952d..661f8b00e1 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -101,6 +101,7 @@ struct SimpleTensor {
     }
     return acc;
   }
+  bool has_data() const noexcept { return dptr != nullptr && numel() > 0; }
 
   void clear() {
     dptr = nullptr;
@@ -154,9 +155,11 @@ struct Tensor {
     return acc;
   }
 
+  // TODO(Tim): Change this to use data.has_data()
   bool has_data() const noexcept { return data.dptr != nullptr; }
 
   // Check for size (not just pointer) for 0-dim or no token cases.
+  // TODO(Tim): Change this to use columnwise_data.has_data()
   bool has_columnwise_data() const noexcept {
     return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0;
   }
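The new `SimpleTensor::has_data()` treats a tensor as populated only when the pointer is non-null *and* it holds at least one element, which is the stricter check the two `TODO(Tim)` notes intend to reuse for `Tensor`. A minimal standalone sketch of the distinction (mock type, not the TE struct):

```cpp
#include <cstddef>

// Hypothetical stand-in mirroring the check added above; illustrative only.
struct MockSimpleTensor {
  void *dptr = nullptr;
  size_t elems = 0;
  size_t numel() const { return elems; }
  bool has_data() const { return dptr != nullptr && numel() > 0; }
};

int main() {
  float buf[4] = {};
  MockSimpleTensor zero_dim{buf, 0};  // valid pointer, zero elements (e.g. no tokens)
  MockSimpleTensor regular{buf, 4};
  bool pointer_only = zero_dim.dptr != nullptr;  // true: old-style check
  bool combined = zero_dim.has_data();           // false: new check rejects the empty case
  return (pointer_only && !combined && regular.has_data()) ? 0 : 1;
}
```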
@@ -281,6 +284,129 @@ struct Tensor {
   }
 };
 
+struct GroupedTensor {
+ public:
+  /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
+  /*
+  Grouped tensor is a collection of tensors with potentially different shapes but the same dtype
+  and scaling mode.
+
+  Shape Representation:
+  - logical_shape: 2D shape representing the conceptual layout, i.e. the shape when member tensors
+    are flattened to 2D and stacked together (REQUIRED)
+      When all_same_shape(): [num_tensors * M, N] where each tensor is (M, N)
+      When varying_first_dim(): [~sum_of_first_dims, N] where N is common
+      When varying_last_dim(): [M, ~sum_of_last_dims] where M is common
+      When varying_both_dims(): [1, total_elements] (fully flattened)
+  - first_dims and last_dims are OPTIONAL (empty if dimension is uniform)
+      Empty first_dims: all tensors have the same first dimension
+      Empty last_dims: all tensors have the same last dimension
+      Both empty: all tensors have identical shapes
+      Both set: each tensor has unique shape (first_dims[i], last_dims[i])
+
+  Data Layout:
+  - ALL data fields are stored as 1D flattened arrays (data, columnwise_data, scale_inv, etc.)
+  - logical_shape provides the conceptual 2D interpretation
+  - All data is stored on device in contiguous layout
+  */
+
+  SimpleTensor data;
+  SimpleTensor columnwise_data;
+  SimpleTensor scale_inv;
+  SimpleTensor columnwise_scale_inv;
+  SimpleTensor amax;
+  SimpleTensor columnwise_amax;
+  SimpleTensor scale;  // for FP8-DS only
+
+  // Shape information (OPTIONAL - empty if dimension is uniform across all tensors)
+  // first_dims[i] = first dimension of tensor i (empty if all tensors have same first dim)
+  // last_dims[i] = last dimension of tensor i (empty if all tensors have same last dim)
+  SimpleTensor first_dims;  // Device pointer to int64_t array of length num_tensors (or empty)
+  SimpleTensor last_dims;   // Device pointer to int64_t array of length num_tensors (or empty)
+
+  // Offsets for indexing into contiguous 1D layout (OPTIONAL - not needed if all_same_shape())
+  // tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for
+  // tensors 0..i-1)
+  // Usage: tensor_i_ptr = (char*)data.dptr + tensor_offsets[i] * element_size
+  // If empty and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions)
+  SimpleTensor tensor_offsets;  // Device pointer to int64_t array of length num_tensors (or empty)
+
+  // Logical shape: conceptual 2D shape of the grouped data (REQUIRED)
+  // Represents how the 1D flattened data should be interpreted as 2D
+  // Always 2D with positive dimensions
+  NVTEShape logical_shape;
+
+  NVTEScalingMode scaling_mode;
+  size_t num_tensors;
+  NVTEGroupedTensor nvte_tensor;
+
+  GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
+      : data(),
+        columnwise_data(),
+        scale_inv(),
+        columnwise_scale_inv(),
+        amax(),
+        columnwise_amax(),
+        scale(),
+        first_dims(nullptr, {}, DType::kInt64),
+        last_dims(nullptr, {}, DType::kInt64),
+        tensor_offsets(nullptr, {}, DType::kInt64),
+        logical_shape(nvte_make_shape(nullptr, 0)),
+        scaling_mode(scaling_mode),
+        num_tensors(num_tensors),
+        nvte_tensor(0) {}
+
+  explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; }
+
+  bool has_data() const noexcept { return data.has_data(); }
+  bool has_columnwise_data() const noexcept { return columnwise_data.has_data(); }
+
+  bool all_same_first_dim() const noexcept { return !first_dims.has_data(); }
+  bool all_same_last_dim() const noexcept { return !last_dims.has_data(); }
+  bool all_same_shape() const noexcept { return !first_dims.has_data() && !last_dims.has_data(); }
+  bool varying_both_dims() const noexcept { return first_dims.has_data() && last_dims.has_data(); }
+
+  size_t get_common_first_dim() const {
+    NVTE_CHECK(all_same_first_dim(), "First dim varies across tensors");
+    NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D");
+    if (all_same_shape()) {
+      // When both dims are uniform: logical_shape = [num_tensors * M, N]
+      return logical_shape.data[0] / num_tensors;
+    } else {
+      // When varying last dims but not first dim: logical_shape = [M, sum_of_last_dims]
+      return logical_shape.data[0];
+    }
+  }
+
+  size_t get_common_last_dim() const {
+    NVTE_CHECK(all_same_last_dim(), "Last dim varies across tensors");
+    NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D");
+    // For both uniform and varying first dim cases: logical_shape[1] is the common last dim
+    return logical_shape.data[1];
+  }
+
+  DType dtype() const {
+    if (has_data()) return data.dtype;
+    if (has_columnwise_data()) return columnwise_data.dtype;
+    // Fallback, used e.g. in workspace or when allow_empty=true
+    return data.dtype;
+  }
+
+  void clear() {
+    data.clear();
+    columnwise_data.clear();
+    scale_inv.clear();
+    columnwise_scale_inv.clear();
+    amax.clear();
+    columnwise_amax.clear();
+    scale.clear();
+    first_dims.clear();
+    last_dims.clear();
+    tensor_offsets.clear();
+    logical_shape = nvte_make_shape(nullptr, 0);
+    num_tensors = 0;
+    scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+    nvte_tensor = 0;
+  }
+};
+
 struct QuantizationConfig {
   bool force_pow_2_scales = false;
   float amax_epsilon = 0.0f;
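The offset bookkeeping above is the part most callers interact with. A hedged host-side sketch of resolving the i-th member tensor, assuming the caller has already copied `tensor_offsets` (when present) to host memory; the mirror struct and helper names are illustrative, not part of this PR:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side mirror of the shape metadata (copied back from device).
struct HostGroupMeta {
  std::vector<int64_t> tensor_offsets;  // empty if all members share one shape
  size_t num_tensors = 0;
  size_t logical_rows = 0, logical_cols = 0;  // logical_shape
};

// Element offset of member tensor i inside the flattened 1D buffer.
inline int64_t element_offset(const HostGroupMeta &g, size_t i) {
  if (!g.tensor_offsets.empty()) return g.tensor_offsets[i];
  // all_same_shape(): logical_shape = [num_tensors * M, N], so offset[i] = i * M * N.
  const size_t M = g.logical_rows / g.num_tensors;
  return static_cast<int64_t>(i * M * g.logical_cols);
}

// Raw pointer to member tensor i, given the flattened data pointer and element size.
inline void *member_ptr(void *data, const HostGroupMeta &g, size_t i, size_t elem_size) {
  return static_cast<char *>(data) + element_offset(g, i) * elem_size;
}
```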
@@ -779,6 +905,16 @@ std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensor
 Tensor *convertNVTETensor(const NVTETensor tensor);
 Tensor *convertNVTETensorCheck(const NVTETensor tensor);
 
+GroupedTensor *convertNVTEGroupedTensor(const NVTEGroupedTensor tensor);
+GroupedTensor *convertNVTEGroupedTensorCheck(const NVTEGroupedTensor tensor);
+
+// Helper functions for GroupedTensor validation
+void CheckGroupedTensorShapeArrays(const GroupedTensor &t, const std::string &name);
+void CheckInputGroupedTensor(const GroupedTensor &t, const std::string &name);
+void CheckOutputGroupedTensor(const GroupedTensor &t, const std::string &name,
+                              bool allow_empty = false);
+
 }  // namespace transformer_engine
 
 #endif  // TRANSFORMER_ENGINE_COMMON_COMMON_H_
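For context on how these declarations are meant to be consumed: a grouped-tensor entry point would mirror the existing pattern of converting the opaque handle and validating it before launching a kernel. A hypothetical sketch (the function `nvte_grouped_example_op` and its body are illustrative only; only the helpers declared above come from this PR):

```cpp
#include <cuda_runtime.h>

#include "common.h"  // assumed include path inside transformer_engine/common

// Hypothetical NVTE-level entry point; not part of this diff.
void nvte_grouped_example_op(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                             cudaStream_t stream) {
  using namespace transformer_engine;
  const GroupedTensor &in = *convertNVTEGroupedTensorCheck(input);
  GroupedTensor &out = *convertNVTEGroupedTensorCheck(output);
  CheckInputGroupedTensor(in, "grouped_input");
  CheckOutputGroupedTensor(out, "grouped_output");
  // ... dispatch a kernel over in.num_tensors members using in.tensor_offsets ...
  (void)stream;
}
```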
diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h
index 1a901ab82d..76cc636a35 100644
--- a/transformer_engine/common/include/transformer_engine/transformer_engine.h
+++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h
@@ -393,6 +393,114 @@ int nvte_is_non_tn_fp8_gemm_supported();
  */
 void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream);
 
+/*! \brief TE Grouped Tensor type
+ *
+ *  NVTEGroupedTensor is a collection of tensors with potentially different shapes
+ *  but the same dtype and scaling mode. It does not own the memory it points to.
+ */
+typedef void *NVTEGroupedTensor;
+
+/*! \enum NVTEGroupedTensorParam
+ *  \brief Indicates the kind of the grouped tensor parameter to set/get.
+ */
+enum NVTEGroupedTensorParam {
+  kNVTEGroupedRowwiseData = 0,        /*!< Data usable in rowwise manner */
+  kNVTEGroupedColumnwiseData = 1,     /*!< Data usable in columnwise manner */
+  kNVTEGroupedScale = 2,              /*!< Scale tensor */
+  kNVTEGroupedAmax = 3,               /*!< Amax tensor */
+  kNVTEGroupedRowwiseScaleInv = 4,    /*!< Scale inverse tensor for decoding Rowwise Data */
+  kNVTEGroupedColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */
+  kNVTEGroupedColumnwiseAmax = 6,     /*!< Columnwise Amax tensor */
+  kNVTEGroupedFirstDims = 7,          /*!< First dimension sizes (device pointer to int64_t array) */
+  kNVTEGroupedLastDims = 8,           /*!< Last dimension sizes (device pointer to int64_t array) */
+  kNVTEGroupedTensorOffsets =
+      9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */
+  kNVTENumGroupedTensorParams
+};
+
+/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
+/*! \brief Create a new TE grouped tensor.
+ *
+ *  Create a new TE grouped tensor. Before use, its parameters need to be set.
+ *  TE grouped tensors are just wrappers on top of raw data and do not
+ *  own memory.
+ *
+ *  \param[in] scaling_mode   Scaling mode of the grouped tensor.
+ *  \param[in] num_tensors    Number of tensors in the group (must be > 0).
+ *  \param[in] logical_shape  Logical 2D shape of the grouped data.
+ *
+ *  \return A new TE grouped tensor.
+ */
+NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_t num_tensors,
+                                             NVTEShape logical_shape);
+
+/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
+/*! \brief Destroy a TE grouped tensor.
+ *
+ *  Since the TE grouped tensor does not own memory, the underlying
+ *  data is not freed during this operation.
+ *
+ *  \param[in] tensor  Grouped tensor to be destroyed.
+ */
+void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor);
+
+/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
+/*! \brief Set a parameter of the grouped tensor.
+ *
+ *  \param[in,out] tensor      Grouped tensor.
+ *  \param[in]     param_name  The parameter to be set.
+ *  \param[in]     param       The value to be set (NVTEBasicTensor).
+ */
+void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name,
+                                   const NVTEBasicTensor *param);
+
+/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
+/*! \brief Get the value of a parameter of the grouped tensor.
+ *
+ *  \param[in] tensor      Grouped tensor.
+ *  \param[in] param_name  The parameter to be queried.
+ *
+ *  \return NVTEBasicTensor containing the parameter data.
+ */
+NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor,
+                                              NVTEGroupedTensorParam param_name);
+
+/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
+/*! \brief Get the number of tensors in a grouped tensor.
+ *
+ *  \param[in] tensor  Grouped tensor.
+ *
+ *  \return Number of tensors in the group.
+ */
+size_t nvte_grouped_tensor_num_tensors(const NVTEGroupedTensor tensor);
+
+/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
+/*! \brief Get a grouped tensor's data type.
+ *
+ *  \param[in] tensor  Grouped tensor.
+ *
+ *  \return The data type of the grouped tensor.
+ */
+NVTEDType nvte_grouped_tensor_type(const NVTEGroupedTensor tensor);
+
+/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
+/*! \brief Get the scaling mode of a grouped tensor.
+ *
+ *  \param[in] tensor  Grouped tensor.
+ *
+ *  \return Scaling mode of the grouped tensor.
+ */
+NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor);
+
+/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
+/*! \brief Get the logical shape of a grouped tensor.
+ *
+ *  \param[in] tensor  Grouped tensor.
+ *
+ *  \return Logical 2D shape of the grouped data.
+ */
+NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor);
+
 #ifdef __cplusplus
 }  // extern "C"
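To make the intended call pattern concrete, here is a hedged usage sketch of the new C API for the uniform-shape case, assuming FP8 data and scale-inverse buffers have already been allocated on device (the buffer pointers and sizes are placeholders; error handling is omitted):

```cpp
#include <transformer_engine/transformer_engine.h>

// Sketch only: wrap num_tensors pre-allocated (M, N) FP8 buffers stored back-to-back on device.
void wrap_grouped_fp8(void *d_data, void *d_scale_inv, size_t num_tensors, size_t M, size_t N) {
  size_t logical[2] = {num_tensors * M, N};  // all member tensors share one shape
  NVTEGroupedTensor gt = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors,
                                                    nvte_make_shape(logical, 2));

  size_t flat[1] = {num_tensors * M * N};
  NVTEBasicTensor data = {d_data, kNVTEFloat8E4M3, nvte_make_shape(flat, 1)};
  nvte_set_grouped_tensor_param(&gt, kNVTEGroupedRowwiseData, &data);

  size_t sinv_len[1] = {num_tensors};
  NVTEBasicTensor scale_inv = {d_scale_inv, kNVTEFloat32, nvte_make_shape(sinv_len, 1)};
  nvte_set_grouped_tensor_param(&gt, kNVTEGroupedRowwiseScaleInv, &scale_inv);

  // first_dims / last_dims / tensor_offsets stay unset: every member has the same shape.
  nvte_destroy_grouped_tensor(gt);
}
```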
diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp
index 314ba3b40f..e9e1d5bfb1 100644
--- a/transformer_engine/common/transformer_engine.cpp
+++ b/transformer_engine/common/transformer_engine.cpp
@@ -273,6 +273,128 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
   CheckScaleTensorShape(t, name);
 }
 
+void CheckGroupedTensorShapeArrays(const GroupedTensor &t, const std::string &name) {
+  NVTE_CHECK(t.num_tensors > 0, "Grouped tensor ", name, " has no tensors!");
+
+  // Helper lambda to validate shape arrays
+  // All three arrays are OPTIONAL:
+  //   - first_dims: empty if all tensors have same first dimension
+  //   - last_dims: empty if all tensors have same last dimension
+  //   - tensor_offsets: empty if all tensors have same shape (offsets are predictable)
+  auto check_shape_array = [&](const SimpleTensor &arr, const char *arr_name) {
+    if (arr.has_data()) {
+      NVTE_CHECK(arr.shape.size() == 1, "Grouped tensor ", name, " ", arr_name, " must be 1D");
+      NVTE_CHECK(arr.dtype == DType::kInt64, "Grouped tensor ", name, " ", arr_name,
+                 " must have dtype Int64");
+      NVTE_CHECK(arr.shape[0] == t.num_tensors, "Grouped tensor ", name, " ", arr_name, " size (",
+                 arr.shape[0], ") must equal num_tensors (", t.num_tensors, ")");
+    }
+  };
+
+  // Validate shape arrays (all optional)
+  check_shape_array(t.first_dims, "first_dims");
+  check_shape_array(t.last_dims, "last_dims");
+  check_shape_array(t.tensor_offsets, "tensor_offsets");
+
+  // tensor_offsets is required if any dimension varies
+  // (i.e., required unless all_same_shape())
+  if (!t.all_same_shape()) {
+    NVTE_CHECK(
+        t.tensor_offsets.dptr != nullptr, "Grouped tensor ", name,
+        " must have tensor_offsets when any dimension varies (first_dims or last_dims is set)");
+  }
+
+  // Validate logical_shape
+  NVTE_CHECK(t.logical_shape.ndim == 2, "Grouped tensor ", name, " logical_shape must be 2D");
+  NVTE_CHECK(t.logical_shape.data[0] > 0 && t.logical_shape.data[1] > 0, "Grouped tensor ", name,
+             " logical_shape must have positive dimensions");
+
+  // Validate all data fields are 1D (flattened)
+  if (t.has_data()) {
+    NVTE_CHECK(t.data.shape.size() == 1, "Grouped tensor ", name, " data must be 1D");
+  }
+  if (t.has_columnwise_data()) {
+    NVTE_CHECK(t.columnwise_data.shape.size() == 1, "Grouped tensor ", name,
+               " columnwise_data must be 1D");
+  }
+
+  // Validate data size matches logical_shape
+  size_t expected_numel = t.logical_shape.data[0] * t.logical_shape.data[1];
+  if (t.has_data()) {
+    NVTE_CHECK(t.data.numel() == expected_numel, "Grouped tensor ", name, " data size (",
+               t.data.numel(), ") must match logical_shape size (", expected_numel, ")");
+  }
+  if (t.has_columnwise_data()) {
+    NVTE_CHECK(t.columnwise_data.numel() == expected_numel, "Grouped tensor ", name,
+               " columnwise_data size (", t.columnwise_data.numel(),
+               ") must match logical_shape size (", expected_numel, ")");
+  }
+}
+
+// Helper function to check scale_inv for both input and output
+static void CheckGroupedScaleInv(const GroupedTensor &t, const std::string &name, bool is_output) {
+  const char *tensor_type = is_output ? "output" : "input";
+
+  // Helper to check scale_inv for both rowwise and columnwise layouts
+  auto check_scales = [&](DType expected_dtype) {
+    if (t.has_data()) {
+      NVTE_CHECK(t.scale_inv.has_data(), tensor_type, " ", name,
+                 " rowwise scale_inv must be allocated");
+      NVTE_CHECK(t.scale_inv.dtype == expected_dtype, tensor_type, " ", name,
+                 " rowwise scale_inv has invalid dtype (expected ", to_string(expected_dtype),
+                 ", got ", to_string(t.scale_inv.dtype), ")");
+    }
+    if (t.has_columnwise_data()) {
+      NVTE_CHECK(t.columnwise_scale_inv.has_data(), tensor_type, " ", name,
+                 " columnwise scale_inv must be allocated");
+      NVTE_CHECK(t.columnwise_scale_inv.dtype == expected_dtype, tensor_type, " ", name,
+                 " columnwise scale_inv has invalid dtype (expected ", to_string(expected_dtype),
+                 ", got ", to_string(t.columnwise_scale_inv.dtype), ")");
+    }
+  };
+
+  // Determine expected dtype based on data type and scaling mode
+  if (is_fp8_dtype(t.dtype()) && is_tensor_scaling(t.scaling_mode)) {
+    check_scales(DType::kFloat32);
+  } else if (is_mxfp8_scaling(t.scaling_mode)) {
+    check_scales(DType::kFloat8E8M0);
+  } else if (is_nvfp4_scaling(t.scaling_mode)) {
+    check_scales(DType::kFloat8E4M3);
+  } else {
+    // Non-quantized types should not have scale/scale_inv
+    NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv not supported for non-quantized ", tensor_type,
+               " ", name);
+    NVTE_CHECK(!t.columnwise_scale_inv.has_data(), "Scale_inv not supported for non-quantized ",
+               tensor_type, " ", name);
+  }
+}
+
+void CheckInputGroupedTensor(const GroupedTensor &t, const std::string &name) {
+  NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input grouped tensor ", name,
+             " not allocated");
+  CheckGroupedScaleInv(t, name, false);
+  CheckGroupedTensorShapeArrays(t, name);
+}
+
+void CheckOutputGroupedTensor(const GroupedTensor &t, const std::string &name, bool allow_empty) {
+  if (!allow_empty) {
+    NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output grouped tensor ", name,
+               " not allocated");
+  }
+
+  // Only perform dtype-specific validation if data is allocated
+  if (t.has_data() || t.has_columnwise_data()) {
+    // Amax validation for delayed scaling
+    if (is_fp8_dtype(t.dtype()) && t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
+      NVTE_CHECK(t.amax.has_data(), "Output ", name, " amax must be allocated");
+      NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Output ", name, " amax must be Float32");
+    }
+    CheckGroupedScaleInv(t, name, true);
+  }
+
+  CheckGroupedTensorShapeArrays(t, name);
+}
+
 class TensorAllocator {
  public:
  static TensorAllocator &instance() {
@@ -387,6 +509,89 @@ Tensor *convertNVTETensorCheck(const NVTETensor t) {
   return ptr;
 }
 
+// GroupedTensor allocator - similar pattern to TensorAllocator
+class GroupedTensorAllocator {
+ public:
+  static GroupedTensorAllocator &instance() {
+    static GroupedTensorAllocator allocator;
+    return allocator;
+  }
+
+  ~GroupedTensorAllocator() {}
+
+  NVTEGroupedTensor Allocate(NVTEScalingMode mode, size_t num_tensors, NVTEShape logical_shape) {
+    std::lock_guard<std::mutex> lock(mutex);
+    if (!free_list.empty()) {
+      uintptr_t index = free_list.back();
+      NVTEGroupedTensor ret = reinterpret_cast<NVTEGroupedTensor>(index);
+      free_list.pop_back();
+      // 1-based indexing - fully reinitialize the tensor to avoid stale data
+      memory[index - 1].scaling_mode = mode;
+      memory[index - 1].num_tensors = num_tensors;
+      memory[index - 1].logical_shape = logical_shape;
+      memory[index - 1].nvte_tensor = ret;
+      return ret;
+    }
+    if (memory.size() < memory.capacity()) {
+      memory.emplace_back(mode, num_tensors);
+      GroupedTensor &t = memory.back();
+      size = memory.size();
+      // 1-based indexing
+      uintptr_t index = memory.size();
+      t.logical_shape = logical_shape;
+      t.nvte_tensor = reinterpret_cast<NVTEGroupedTensor>(index);
+      return reinterpret_cast<NVTEGroupedTensor>(index);
+    }
+    NVTE_ERROR(
+        "Cannot allocate a new NVTEGroupedTensor. Maximum number of grouped tensors reached: ",
+        MAX_GROUPED_TENSOR_NUM, ". There is probably a memory leak in your application.");
+  }
+
+  void Free(NVTEGroupedTensor t) {
+    std::lock_guard<std::mutex> lock(mutex);
+    uintptr_t index = reinterpret_cast<uintptr_t>(t);
+    if (index == 0) return;
+    NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor.");
+    free_list.push_back(index);
+    // Clean up
+    memory[index - 1].clear();
+  }
+
+  GroupedTensor *convertNVTEGroupedTensor(NVTEGroupedTensor t) {
+    uintptr_t index = reinterpret_cast<uintptr_t>(t);
+    // 1-based indexing to enable 0-initialization of NVTEGroupedTensor
+    // to be invalid tensor
+    static_assert(nullptr == 0);
+    if (index != 0 && index <= size) {
+      return &(memory[index - 1]);
+    }
+    return nullptr;
+  }
+
+ private:
+  GroupedTensorAllocator() {
+    std::lock_guard<std::mutex> lock(mutex);
+    memory.reserve(MAX_GROUPED_TENSOR_NUM);
+  }
+
+  std::mutex mutex;
+  std::atomic<size_t> size;
+  // Allocate at most 20 MB for grouped tensors
+  const size_t MAX_GROUPED_TENSOR_NUM = 20 * 1024 * 1024 / sizeof(GroupedTensor);
+  std::vector<uintptr_t> free_list;
+  std::vector<GroupedTensor> memory;
+};
+
+GroupedTensor *convertNVTEGroupedTensor(const NVTEGroupedTensor t) {
+  return GroupedTensorAllocator::instance().convertNVTEGroupedTensor(t);
+}
+
+GroupedTensor *convertNVTEGroupedTensorCheck(const NVTEGroupedTensor t) {
+  GroupedTensor *ptr = GroupedTensorAllocator::instance().convertNVTEGroupedTensor(t);
+  NVTE_CHECK(ptr != nullptr, "Invalid grouped tensor.");
+  return ptr;
+}
+
 }  // namespace transformer_engine
 
 NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) {
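One consequence of the allocator design above, worth keeping in mind when calling the C API implemented below: handles are 1-based indices packed into the pointer type, so a default-initialized NVTEGroupedTensor is deliberately invalid and the query functions fall back to benign defaults. A hedged sketch of that behavior, using only functions introduced in this diff:

```cpp
#include <transformer_engine/transformer_engine.h>

void handle_semantics_example() {
  NVTEGroupedTensor empty = nullptr;  // never returned by nvte_create_grouped_tensor
  size_t n = nvte_grouped_tensor_num_tensors(empty);  // 0: conversion yields nullptr
  NVTEDType dt = nvte_grouped_tensor_type(empty);     // kNVTEFloat32 fallback
  (void)n;
  (void)dt;
}
```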
@@ -730,3 +935,132 @@ int nvte_is_non_tn_fp8_gemm_supported() {
   });
   return cache[device_id];
 }
+
+// Grouped Tensor C API implementations
+NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_t num_tensors,
+                                             NVTEShape logical_shape) {
+  NVTE_CHECK(num_tensors > 0, "Number of tensors must be greater than 0");
+  NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D");
+  NVTE_CHECK(logical_shape.data[0] > 0 && logical_shape.data[1] > 0,
+             "Logical shape must have positive dimensions");
+  NVTEGroupedTensor ret = transformer_engine::GroupedTensorAllocator::instance().Allocate(
+      scaling_mode, num_tensors, logical_shape);
+  return ret;
+}
+
+void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor) {
+  transformer_engine::GroupedTensorAllocator::instance().Free(tensor);
+}
+
+void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name,
+                                   const NVTEBasicTensor *param) {
+  NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer can't be NULL.");
+  auto *t = transformer_engine::convertNVTEGroupedTensor(*tensor);
+  NVTE_CHECK(t != nullptr, "Grouped tensor is not allocated.");
+  NVTE_CHECK(param != nullptr, "Grouped tensor param can't be NULL.");
+
+  switch (param_name) {
+    case kNVTEGroupedRowwiseData:
+      t->data = *param;
+      break;
+    case kNVTEGroupedColumnwiseData:
+      t->columnwise_data = *param;
+      break;
+    case kNVTEGroupedScale:
+      t->scale = *param;
+      break;
+    case kNVTEGroupedAmax:
+      t->amax = *param;
+      break;
+    case kNVTEGroupedRowwiseScaleInv:
+      t->scale_inv = *param;
+      break;
+    case kNVTEGroupedColumnwiseScaleInv:
+      t->columnwise_scale_inv = *param;
+      break;
+    case kNVTEGroupedColumnwiseAmax:
+      t->columnwise_amax = *param;
+      break;
+    case kNVTEGroupedFirstDims:
+      t->first_dims = *param;
+      // Validate it's Int64
+      NVTE_CHECK(t->first_dims.dtype == transformer_engine::DType::kInt64,
+                 "first_dims must have dtype Int64");
+      break;
+    case kNVTEGroupedLastDims:
+      t->last_dims = *param;
+      // Validate it's Int64
+      NVTE_CHECK(t->last_dims.dtype == transformer_engine::DType::kInt64,
+                 "last_dims must have dtype Int64");
+      break;
+    case kNVTEGroupedTensorOffsets:
+      t->tensor_offsets = *param;
+      // Validate it's Int64
+      NVTE_CHECK(t->tensor_offsets.dtype == transformer_engine::DType::kInt64,
+                 "tensor_offsets must have dtype Int64");
+      break;
+    default:
+      NVTE_ERROR("Unknown grouped tensor parameter!");
+  }
+}
+
+NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor,
+                                              NVTEGroupedTensorParam param_name) {
+  if (tensor == nullptr) {
+    return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
+  }
+  const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
+
+  switch (param_name) {
+    case kNVTEGroupedRowwiseData:
+      return t.data;
+    case kNVTEGroupedColumnwiseData:
+      return t.columnwise_data;
+    case kNVTEGroupedScale:
+      return t.scale;
+    case kNVTEGroupedAmax:
+      return t.amax;
+    case kNVTEGroupedRowwiseScaleInv:
+      return t.scale_inv;
+    case kNVTEGroupedColumnwiseScaleInv:
+      return t.columnwise_scale_inv;
+    case kNVTEGroupedColumnwiseAmax:
+      return t.columnwise_amax;
+    case kNVTEGroupedFirstDims:
+      return t.first_dims;
+    case kNVTEGroupedLastDims:
+      return t.last_dims;
+    case kNVTEGroupedTensorOffsets:
+      return t.tensor_offsets;
+    default:
+      NVTE_ERROR("Unknown grouped tensor parameter!");
+  }
+}
+
+size_t nvte_grouped_tensor_num_tensors(const NVTEGroupedTensor tensor) {
+  auto *t = transformer_engine::convertNVTEGroupedTensor(tensor);
+  if (t == nullptr) return 0;
+  return t->num_tensors;
+}
+
+NVTEDType nvte_grouped_tensor_type(const NVTEGroupedTensor tensor) {
+  auto *t = transformer_engine::convertNVTEGroupedTensor(tensor);
+  if (t == nullptr) return kNVTEFloat32;
+  return static_cast<NVTEDType>(t->dtype());
+}
+
+NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor) {
+  if (tensor == nullptr) {
+    return NVTE_DELAYED_TENSOR_SCALING;
+  }
+  const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
+  return t.scaling_mode;
+}
+
+NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) {
+  if (tensor == nullptr) {
+    return nvte_make_shape(nullptr, 0);
+  }
+  const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
+  return t.logical_shape;
+}
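To close the loop on the setters above, here is a hedged sketch of preparing and attaching the per-member metadata for a group whose first dimension varies: first_dims holds the row counts, tensor_offsets holds the exclusive prefix sum of member numels, and last_dims stays unset because the column count is uniform. The helper name and cudaMalloc-based upload are illustrative; error handling and data/scale attachment are omitted:

```cpp
#include <cstdint>
#include <vector>

#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>

// Sketch: rows[i] is the first dimension of member i; all members share `cols` columns.
NVTEGroupedTensor make_varying_rows_group(const std::vector<int64_t> &rows, int64_t cols) {
  const size_t n = rows.size();
  std::vector<int64_t> offsets(n);
  int64_t total = 0;
  for (size_t i = 0; i < n; ++i) {
    offsets[i] = total;       // element offset of member i (exclusive prefix sum)
    total += rows[i] * cols;  // numel of member i
  }

  int64_t *d_rows = nullptr, *d_offsets = nullptr;
  cudaMalloc(reinterpret_cast<void **>(&d_rows), n * sizeof(int64_t));
  cudaMalloc(reinterpret_cast<void **>(&d_offsets), n * sizeof(int64_t));
  cudaMemcpy(d_rows, rows.data(), n * sizeof(int64_t), cudaMemcpyHostToDevice);
  cudaMemcpy(d_offsets, offsets.data(), n * sizeof(int64_t), cudaMemcpyHostToDevice);

  // logical_shape = [sum_of_first_dims, cols]
  size_t logical[2] = {static_cast<size_t>(total) / static_cast<size_t>(cols),
                       static_cast<size_t>(cols)};
  NVTEGroupedTensor gt =
      nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, n, nvte_make_shape(logical, 2));

  size_t len[1] = {n};
  NVTEBasicTensor first_dims = {d_rows, kNVTEInt64, nvte_make_shape(len, 1)};
  NVTEBasicTensor tensor_offsets = {d_offsets, kNVTEInt64, nvte_make_shape(len, 1)};
  nvte_set_grouped_tensor_param(&gt, kNVTEGroupedFirstDims, &first_dims);
  nvte_set_grouped_tensor_param(&gt, kNVTEGroupedTensorOffsets, &tensor_offsets);
  return gt;  // data and scale_inv params would be attached before use
}
```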