[Common] NVTEGroupedTensor class and helpers #2388
base: main
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -393,6 +393,114 @@ int nvte_is_non_tn_fp8_gemm_supported(); | |||||||
| */ | ||||||||
| void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream); | ||||||||
|
|
||||||||
| /*! \brief TE Grouped Tensor type | ||||||||
| * | ||||||||
| * NVTEGroupedTensor is a collection of tensors with potentially different shapes | ||||||||
| * but the same dtype and scaling mode. It does not own the memory it points to. | ||||||||
| */ | ||||||||
| typedef void *NVTEGroupedTensor; | ||||||||
|
|
||||||||
| /*! \enum NVTEGroupedTensorParam | ||||||||
| * \brief Indicates the kind of the grouped tensor parameter to set/get. | ||||||||
| */ | ||||||||
| enum NVTEGroupedTensorParam { | ||||||||
| kNVTEGroupedRowwiseData = 0, /*!< Data usable in rowwise manner */ | ||||||||
| kNVTEGroupedColumnwiseData = 1, /*!< Data usable in columnwise manner */ | ||||||||
| kNVTEGroupedScale = 2, /*!< Scale tensor */ | ||||||||
| kNVTEGroupedAmax = 3, /*!< Amax tensor */ | ||||||||
| kNVTEGroupedRowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */ | ||||||||
| kNVTEGroupedColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */ | ||||||||
| kNVTEGroupedColumnwiseAmax = 6, /*!< Columnwise Amax tensor */ | ||||||||
| kNVTEGroupedFirstDims = 7, /*!< First dimension sizes (device pointer to int64_t array) */ | ||||||||
| kNVTEGroupedLastDims = 8, /*!< Last dimension sizes (device pointer to int64_t array) */ | ||||||||
| kNVTEGroupedTensorOffsets = | ||||||||
| 9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */ | ||||||||
| kNVTENumGroupedTensorParams | ||||||||
| }; | ||||||||
|
|
||||||||
| /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ | ||||||||
| /*! \brief Create a new TE grouped tensor. | ||||||||
| * | ||||||||
| * Create a new TE grouped tensor. Before use its parameters need to be set. | ||||||||
| * TE grouped tensors are just wrappers on top of raw data and do not | ||||||||
| * own memory. | ||||||||
| * | ||||||||
| * \param[in] scaling_mode Scaling mode of the grouped tensor. | ||||||||
| * \param[in] num_tensors Number of tensors in the group (must be > 0). | ||||||||
| * \param[in] logical_shape Logical 2D shape of the grouped data. | ||||||||
| * | ||||||||
| * \return A new TE grouped tensor. | ||||||||
| */ | ||||||||
| NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_t num_tensors, | ||||||||
| NVTEShape logical_shape); | ||||||||
|
||||||||
|
|
||||||||
| /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ | ||||||||
| /*! \brief Destroy a TE grouped tensor. | ||||||||
| * | ||||||||
| * Since the TE grouped tensor does not own memory, the underlying | ||||||||
| * data is not freed during this operation. | ||||||||
| * | ||||||||
| * \param[in] tensor Grouped tensor to be destroyed. | ||||||||
| */ | ||||||||
| void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor); | ||||||||
|
|
||||||||
| /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ | ||||||||
| /*! \brief Set a parameter of the grouped tensor. | ||||||||
| * | ||||||||
| * \param[in,out] tensor Grouped tensor. | ||||||||
| * \param[in] param_name The parameter to be set. | ||||||||
| * \param[in] param The value to be set (NVTEBasicTensor). | ||||||||
| */ | ||||||||
| void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name, | ||||||||
| const NVTEBasicTensor *param); | ||||||||
|
Comment on lines +454 to +455
Collaborator
This works when we're setting basic tensors, but doesn't generalize to other types ( TransformerEngine/transformer_engine/common/include/transformer_engine/transformer_engine.h Lines 369 to 371 in f8cb598
This is completely general, but also more cumbersome.
Collaborator (Author)
Hi, I don't think we can go with a similar API, i.e., using just
Collaborator
Yet. |
||||||||
|
|
||||||||
| /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ | ||||||||
| /*! \brief Get the value of a parameter of the grouped tensor. | ||||||||
| * | ||||||||
| * \param[in] tensor Grouped tensor. | ||||||||
| * \param[in] param_name The parameter to be queried. | ||||||||
| * | ||||||||
| * \return NVTEBasicTensor containing the parameter data. | ||||||||
| */ | ||||||||
| NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, | ||||||||
| NVTEGroupedTensorParam param_name); | ||||||||
|
|
||||||||
| /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ | ||||||||
| /*! \brief Get the number of tensors in a grouped tensor. | ||||||||
| * | ||||||||
| * \param[in] tensor Grouped tensor. | ||||||||
| * | ||||||||
| * \return Number of tensors in the group. | ||||||||
| */ | ||||||||
| size_t nvte_grouped_tensor_num_tensors(const NVTEGroupedTensor tensor); | ||||||||
|
|
||||||||
| /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ | ||||||||
| /*! \brief Get a grouped tensor's data type. | ||||||||
| * | ||||||||
| * \param[in] tensor Grouped tensor. | ||||||||
| * | ||||||||
| * \return Data type of the grouped tensor. | ||||||||
| */ | ||||||||
| NVTEDType nvte_grouped_tensor_type(const NVTEGroupedTensor tensor); | ||||||||
|
|
||||||||
| /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ | ||||||||
| /*! \brief Get the scaling mode of a grouped tensor. | ||||||||
| * | ||||||||
| * \param[in] tensor Grouped tensor. | ||||||||
| * | ||||||||
| * \return Scaling mode of the grouped tensor. | ||||||||
| */ | ||||||||
| NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor); | ||||||||
|
|
||||||||
| /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ | ||||||||
| /*! \brief Get the logical shape of a grouped tensor. | ||||||||
| * | ||||||||
| * \param[in] tensor Grouped tensor. | ||||||||
| * | ||||||||
| * \return Logical 2D shape. | ||||||||
| */ | ||||||||
| NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor); | ||||||||
|
|
||||||||
| #ifdef __cplusplus | ||||||||
| } // extern "C" | ||||||||
|
|
||||||||
|
|
||||||||
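The header documents `kNVTEGroupedFirstDims`, `kNVTEGroupedLastDims`, and `kNVTEGroupedTensorOffsets` as device pointers to `int64_t` arrays, but it does not spell out how the offsets relate to the per-tensor dimensions. The sketch below shows one plausible host-side way to derive them. It assumes each member is a dense 2D tensor of `first_dims[i] x last_dims[i]` elements, that members are packed back to back with no padding, and that offsets are counted in elements. The helper name `grouped_offsets_sketch` is hypothetical and not part of Transformer Engine.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper, not part of the Transformer Engine API.
 * Assumptions: member i is a dense first_dims[i] x last_dims[i] tensor,
 * members are stored contiguously without padding, and offsets are
 * expressed in elements (not bytes).
 * Writes num_tensors offsets and returns the total element count. */
static int64_t grouped_offsets_sketch(const int64_t *first_dims, const int64_t *last_dims,
                                      size_t num_tensors, int64_t *offsets) {
  assert(first_dims != NULL && last_dims != NULL && offsets != NULL);
  int64_t running = 0;
  for (size_t i = 0; i < num_tensors; ++i) {
    assert(first_dims[i] >= 0 && last_dims[i] >= 0);
    offsets[i] = running;                     /* start of member i */
    running += first_dims[i] * last_dims[i];  /* advance past member i */
  }
  return running;
}
```

Under these assumptions, a group with first dims `{2, 3, 1}` and a shared last dim of `4` would get offsets `{0, 8, 20}` and a total of 24 elements. Since the API expects device pointers, the resulting host arrays would still need to be copied to device memory before being passed to `nvte_set_grouped_tensor_param`.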
Having a giant pile of variables is fine since we want to make progress on this quickly, but in the future we should consider refactoring to handle polymorphism more gracefully: