29 commits
801936c
add grouped_tensor classes and helpers
phu0ngng Nov 14, 2025
36fb4c0
rm non-contiguous option and dptrs
phu0ngng Nov 18, 2025
3df4d2f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
b798042
address comments + rework CheckIn/OutputGroupedTensor
phu0ngng Nov 20, 2025
7731134
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2025
8bb91c6
fix for compilation
phu0ngng Nov 20, 2025
8628579
make first_dims/last_dims optional + data.shape 2d
phu0ngng Nov 21, 2025
1e3921a
added assertion
phu0ngng Nov 21, 2025
116d907
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 21, 2025
fb0e77b
rs conflicts
phu0ngng Nov 24, 2025
58f5b2a
add data.shape info
phu0ngng Nov 24, 2025
5420c72
added logical shape field
phu0ngng Nov 24, 2025
bc90cac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2025
070ba8e
compilation fix
phu0ngng Nov 24, 2025
981ac84
Merge branch 'main' into nvte_grouped_tensor
phu0ngng Nov 24, 2025
938ec98
fixed issues raised by greptile
phu0ngng Nov 24, 2025
4424b7c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2025
1295d8e
return default dtype when grouped_tensor is empty
phu0ngng Nov 24, 2025
69bd334
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2025
8a35897
use has_data() for dim queries
phu0ngng Nov 25, 2025
a490bd1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 25, 2025
b1c4f68
update comments
phu0ngng Nov 26, 2025
52f88f9
fix index bound
phu0ngng Nov 26, 2025
0e270c5
Update transformer_engine/common/transformer_engine.cpp
phu0ngng Nov 26, 2025
9457f2c
Update transformer_engine/common/transformer_engine.cpp
phu0ngng Nov 26, 2025
b701cc9
restore Tensor.has_data() + add experimental marks
phu0ngng Nov 26, 2025
ad01ee6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 26, 2025
d96da24
restore Tensor::has_columnwise_data
phu0ngng Nov 26, 2025
0b4fd40
cleanup
phu0ngng Dec 1, 2025
136 changes: 136 additions & 0 deletions transformer_engine/common/common.h
@@ -101,6 +101,7 @@ struct SimpleTensor {
}
return acc;
}
bool has_data() const noexcept { return dptr != nullptr && numel() > 0; }

void clear() {
dptr = nullptr;
@@ -154,9 +155,11 @@ struct Tensor {
return acc;
}

// TODO(Tim): Change this to use data.has_data()
bool has_data() const noexcept { return data.dptr != nullptr; }

// Check for size (not just pointer) for 0-dim or no token cases.
// TODO(Tim): Change this to use columnwise_data.has_data()
bool has_columnwise_data() const noexcept {
return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0;
}
@@ -281,6 +284,129 @@ struct Tensor {
}
};

struct GroupedTensor {
public:
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*
A grouped tensor is a collection of tensors that may have different shapes but share the same dtype and scaling mode.

Shape Representation:
- logical_shape: 2D shape representing the conceptual layout, i.e. the shape when member tensors are flattened to 2D and stacked together (REQUIRED)
+ When all_same_shape(): [num_tensors * M, N] where each tensor is (M, N)
+ When only the first dim varies: [sum_of_first_dims, N] where N is common
+ When only the last dim varies: [M, sum_of_last_dims] where M is common
+ When varying_both_dims(): [1, total_elements] (fully flattened)

- first_dims and last_dims are OPTIONAL (empty if dimension is uniform)
+ Empty first_dims: all tensors have the same first dimension
+ Empty last_dims: all tensors have the same last dimension
+ Both empty: all tensors have identical shapes
+ Both set: each tensor has unique shape (first_dims[i], last_dims[i])

Data Layout:
- ALL data fields are stored as 1D flattened arrays (data, columnwise_data, scale_inv, etc.)
- logical_shape provides the conceptual 2D interpretation
- All data is stored on device in contiguous layout
*/

SimpleTensor data;
SimpleTensor columnwise_data;
SimpleTensor scale_inv;
SimpleTensor columnwise_scale_inv;
SimpleTensor amax;
SimpleTensor columnwise_amax;
SimpleTensor scale; // for FP8-DS only
Comment on lines +312 to +318
Collaborator
Having a giant pile of variables is fine since we want to make progress on this quickly, but in the future we should consider refactoring to handle polymorphism more gracefully:

// Visitor pattern
struct GroupedTensor {
 public:
  struct FP8Data {
    std::optional<SimpleTensor> data;
    std::optional<SimpleTensor> transpose;
    SimpleTensor scale_inv;
    std::optional<SimpleTensor> amax;
    std::optional<SimpleTensor> scale;
  };
  struct MXFP8Data {
    std::optional<std::tuple<SimpleTensor, SimpleTensor>> rowwise_data_and_scale;
    std::optional<std::tuple<SimpleTensor, SimpleTensor>> columnwise_data_and_scale;
};
  std::variant<FP8Data, MXFP8Data, ...> data;
};

// Inheritance pattern
struct GroupedTensor { ... };
struct FP8GroupedTensor : public GroupedTensor {
    std::optional<SimpleTensor> data;
    std::optional<SimpleTensor> transpose;
    SimpleTensor scale_inv;
    std::optional<SimpleTensor> amax;
    std::optional<SimpleTensor> scale;
};
struct MXFP8GroupedTensor : public GroupedTensor {
    std::optional<std::tuple<SimpleTensor, SimpleTensor>> rowwise_data_and_scale;
    std::optional<std::tuple<SimpleTensor, SimpleTensor>> columnwise_data_and_scale;
};


// Shape information (OPTIONAL - empty if dimension is uniform across all tensors)
// first_dims[i] = first dimension of tensor i (empty if all tensors have same first dim)
// last_dims[i] = last dimension of tensor i (empty if all tensors have same last dim)
SimpleTensor first_dims; // Device pointer to int64_t array of length num_tensors (or empty)
SimpleTensor last_dims; // Device pointer to int64_t array of length num_tensors (or empty)

// Offsets for indexing into contiguous 1D layout (OPTIONAL - not needed if all_same_shape())
// tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for tensors 0..i-1)
// Usage: tensor_i_ptr = (char*)data.dptr + tensor_offsets[i] * element_size
// If empty and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions)
SimpleTensor tensor_offsets; // Device pointer to int64_t array of length num_tensors (or empty)

// Logical shape: conceptual 2D shape of the grouped data (REQUIRED)
// Represents how the 1D flattened data should be interpreted as 2D
// Always 2D with positive dimensions
NVTEShape logical_shape;

NVTEScalingMode scaling_mode;
size_t num_tensors;
NVTEGroupedTensor nvte_tensor;

GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
: data(),
columnwise_data(),
scale_inv(),
columnwise_scale_inv(),
amax(),
columnwise_amax(),
scale(),
first_dims(nullptr, {}, DType::kInt64),
last_dims(nullptr, {}, DType::kInt64),
tensor_offsets(nullptr, {}, DType::kInt64),
logical_shape(nvte_make_shape(nullptr, 0)),
scaling_mode(scaling_mode),
num_tensors(num_tensors),
nvte_tensor(0) {}

explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; }

bool has_data() const noexcept { return data.has_data(); }
bool has_columnwise_data() const noexcept { return columnwise_data.has_data(); }

bool all_same_first_dim() const noexcept { return !first_dims.has_data(); }
bool all_same_last_dim() const noexcept { return !last_dims.has_data(); }
bool all_same_shape() const noexcept { return !first_dims.has_data() && !last_dims.has_data(); }
bool varying_both_dims() const noexcept { return first_dims.has_data() && last_dims.has_data(); }

size_t get_common_first_dim() const {
NVTE_CHECK(all_same_first_dim(), "First dim varies across tensors");
NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D");
if (all_same_shape()) {
// When both dims are uniform: logical_shape = [num_tensors * M, N]
return logical_shape.data[0] / num_tensors;
} else {
// When varying last dims but not first dim: logical_shape = [M, sum_of_last_dims]
return logical_shape.data[0];
}
}
size_t get_common_last_dim() const {
NVTE_CHECK(all_same_last_dim(), "Last dim varies across tensors");
NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D");
// For both uniform and varying first dim cases: logical_shape[1] is the common last dim
return logical_shape.data[1];
}

DType dtype() const {
if (has_data()) return data.dtype;
if (has_columnwise_data()) return columnwise_data.dtype;
// Fallback, used e.g. in workspace or when allow_empty=true
return data.dtype;
}

void clear() {
data.clear();
columnwise_data.clear();
scale_inv.clear();
columnwise_scale_inv.clear();
amax.clear();
columnwise_amax.clear();
scale.clear();
first_dims.clear();
last_dims.clear();
tensor_offsets.clear();
logical_shape = nvte_make_shape(nullptr, 0);
num_tensors = 0;
scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
nvte_tensor = 0;
}
};

struct QuantizationConfig {
bool force_pow_2_scales = false;
float amax_epsilon = 0.0f;
@@ -779,6 +905,16 @@ std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensor

Tensor *convertNVTETensor(const NVTETensor tensor);
Tensor *convertNVTETensorCheck(const NVTETensor tensor);

GroupedTensor *convertNVTEGroupedTensor(const NVTEGroupedTensor tensor);
GroupedTensor *convertNVTEGroupedTensorCheck(const NVTEGroupedTensor tensor);

// Helper functions for GroupedTensor validation
void CheckGroupedTensorShapeArrays(const GroupedTensor &t, const std::string &name);
void CheckInputGroupedTensor(const GroupedTensor &t, const std::string &name);
void CheckOutputGroupedTensor(const GroupedTensor &t, const std::string &name,
bool allow_empty = false);

} // namespace transformer_engine

#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
@@ -393,6 +393,114 @@ int nvte_is_non_tn_fp8_gemm_supported();
*/
void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream);

/*! \brief TE Grouped Tensor type
*
* NVTEGroupedTensor is a collection of tensors with potentially different shapes
* but the same dtype and scaling mode. It does not own the memory it points to.
*/
typedef void *NVTEGroupedTensor;

/*! \enum NVTEGroupedTensorParam
* \brief Indicates the kind of the grouped tensor parameter to set/get.
*/
enum NVTEGroupedTensorParam {
kNVTEGroupedRowwiseData = 0, /*!< Data usable in rowwise manner */
kNVTEGroupedColumnwiseData = 1, /*!< Data usable in columnwise manner */
kNVTEGroupedScale = 2, /*!< Scale tensor */
kNVTEGroupedAmax = 3, /*!< Amax tensor */
kNVTEGroupedRowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */
kNVTEGroupedColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */
kNVTEGroupedColumnwiseAmax = 6, /*!< Columnwise Amax tensor */
kNVTEGroupedFirstDims = 7, /*!< First dimension sizes (device pointer to int64_t array) */
kNVTEGroupedLastDims = 8, /*!< Last dimension sizes (device pointer to int64_t array) */
kNVTEGroupedTensorOffsets =
9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */
kNVTENumGroupedTensorParams
};

/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Create a new TE grouped tensor.
*
* Create a new TE grouped tensor. Before use, its parameters need to be set.
* TE grouped tensors are just wrappers on top of raw data and do not
* own memory.
*
* \param[in] scaling_mode Scaling mode of the grouped tensor.
* \param[in] num_tensors Number of tensors in the group (must be > 0).
* \param[in] logical_shape Logical 2D shape of the grouped data.
*
* \return A new TE grouped tensor.
*/
NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_t num_tensors,
NVTEShape logical_shape);

/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Destroy a TE grouped tensor.
*
* Since the TE grouped tensor does not own memory, the underlying
* data is not freed during this operation.
*
* \param[in] tensor Grouped tensor to be destroyed.
*/
void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor);

/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Set a parameter of the grouped tensor.
*
* \param[in,out] tensor Grouped tensor.
* \param[in] param_name The parameter to be set.
* \param[in] param The value to be set (NVTEBasicTensor).
*/
void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name,
const NVTEBasicTensor *param);
Comment on lines +454 to +455
Collaborator
This works when we're setting basic tensors, but doesn't generalize to other types (bool/float/etc). Consider using a more general API like how we handle NVTEQuantizationConfig:

void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
NVTEQuantizationConfigAttribute attr, const void *buf,
size_t size_in_bytes);

This is completely general, but also more cumbersome.

Collaborator Author
Hi, I don't think we can go with a similar API, i.e., using just void* buf and size_in_bytes as we do need different dtype for different fields.

Collaborator
Yet.


/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get a value of the parameter of the grouped tensor.
*
* \param[in] tensor Grouped tensor.
* \param[in] param_name The parameter to be queried.
*
* \return NVTEBasicTensor containing the parameter data.
*/
NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor,
NVTEGroupedTensorParam param_name);

/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get the number of tensors in a grouped tensor.
*
* \param[in] tensor Grouped tensor.
*
* \return Number of tensors in the group.
*/
size_t nvte_grouped_tensor_num_tensors(const NVTEGroupedTensor tensor);

/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get a grouped tensor's data type.
*
* \param[in] tensor Grouped tensor.
*
* \return The data type of the grouped tensor.
*/
NVTEDType nvte_grouped_tensor_type(const NVTEGroupedTensor tensor);

/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get the scaling mode of the grouped tensor.
*
* \param[in] tensor Grouped tensor.
*
* \return Scaling mode of the grouped tensor.
*/
NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor);

/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get the logical shape of a grouped tensor.
*
* \param[in] tensor Grouped tensor.
*
* \return Logical 2D shape.
*/
NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor);

#ifdef __cplusplus
} // extern "C"
