[Common] NVTEGroupedTensor class and helpers #2388
base: main
Conversation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
LGTM as discussed offline. It would be better if we could have some example usage of the new API. Otherwise, the new checkGroupOutputTensor seems complicated and I am not sure if it's too strict.
transformer_engine/common/common.h
Outdated
// [TODO] Discuss whether the first_dims and second_dims should be according to layout N
// Shape information: first_dims[i] and second_dims[i] define the shape of the i-th tensor
// For 2D tensors: shape[i] = (first_dims[i], second_dims[i])
SimpleTensor first_dims;   // Device pointer to size_t array of length num_tensors
SimpleTensor second_dims;  // Device pointer to size_t array of length num_tensors
The way I think about it is:
- We first need to standardize which direction determines what is "first" and what is "second" (I prefer "last", BTW) -> I vote for rowwise.
- Then my thinking is that if the rowwise allocation has shape [m, k], the existence (or not) of those shapes would tell us which of the dimensions is constant (e.g. second_dim being uninitialized would mean that all tensors have shape [m_i, k]), which could be used for additional optimizations (e.g. via a specialized kernel choice).
Alternatively the "reference" shape could be a property of the GroupedTensor itself just to avoid setting the shape on otherwise uninitialized rowwise tensor.
- I'm interpreting these dims as the logical tensor dims, which match the row-wise data dims. Logical dims are completely independent of the data format, regardless of whether the column-wise data is transposed or not.
- I like the idea of the grouped tensor holding the "reference" shape and using it depending on whether `first_dim`/`second_dim` are empty. I don't think we can rely on the shape of the data tensors, since they might need to be flattened to 1D, e.g. the FP8 transpose when splitting along the first dim.
- Being able to split along 2 dims makes this very general. However, for MoE we always split along the first logical dim (PyTorch has column-major weights and JAX has row-major, so we need to swap the usage of row-wise and column-wise data; however, we are still splitting along the first logical dim). We should decide whether future-proofing is worth the extra complexity.
Actually, on second thought I think my last bullet point is still true. For MoE, we are always splitting along the first logical dim. Internally we might be splitting a transposed matrix along the last dim, but that is a detail within the group tensor class, and shouldn't be exposed in the public API.
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
SimpleTensor data;
SimpleTensor columnwise_data;
SimpleTensor scale_inv;
SimpleTensor columnwise_scale_inv;
SimpleTensor amax;
SimpleTensor columnwise_amax;
SimpleTensor scale;  // for FP8-DS only
Having a giant pile of variables is fine since we want to make progress on this quickly, but in the future we should consider refactoring to handle polymorphism more gracefully:
// Visitor pattern
struct GroupedTensor {
public:
struct FP8Data {
std::optional<SimpleTensor> data;
std::optional<SimpleTensor> transpose;
SimpleTensor scale_inv;
std::optional<SimpleTensor> amax;
std::optional<SimpleTensor> scale;
};
struct MXFP8Data {
std::optional<std::tuple<SimpleTensor, SimpleTensor>> rowwise_data_and_scale;
std::optional<std::tuple<SimpleTensor, SimpleTensor>> columnwise_data_and_scale;
};
std::variant<FP8Data, MXFP8Data, ...> data;
};
// Inheritance pattern
struct GroupedTensor { ... };
struct FP8GroupedTensor : public GroupedTensor {
std::optional<SimpleTensor> data;
std::optional<SimpleTensor> transpose;
SimpleTensor scale_inv;
std::optional<SimpleTensor> amax;
std::optional<SimpleTensor> scale;
};
struct MXFP8GroupedTensor : public GroupedTensor {
std::optional<std::tuple<SimpleTensor, SimpleTensor>> rowwise_data_and_scale;
std::optional<std::tuple<SimpleTensor, SimpleTensor>> columnwise_data_and_scale;
};
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
/te-ci L0
Greptile Overview

Greptile Summary: This PR introduces the NVTEGroupedTensor class and helpers.

Key additions:

Previous review comments addressed.

Confidence Score: 3/5

Important Files Changed / File Analysis
Sequence Diagram

sequenceDiagram
participant User
participant C_API as C API Layer
participant Allocator as GroupedTensorAllocator
participant Memory as Vector<GroupedTensor>
participant Validator as Validation Functions
User->>C_API: nvte_create_grouped_tensor(mode, num_tensors, logical_shape)
C_API->>C_API: Validate num_tensors > 0
C_API->>C_API: Validate logical_shape 2D and positive
C_API->>Allocator: Allocate(mode, num_tensors, logical_shape)
alt Free list not empty
Allocator->>Allocator: Pop index from free_list
Allocator->>Memory: memory[index-1].clear()
Allocator->>Memory: Set scaling_mode, num_tensors, logical_shape
else New allocation needed
Allocator->>Memory: emplace_back(mode, num_tensors)
Allocator->>Allocator: Update atomic size variable
Allocator->>Memory: Set logical_shape
end
Allocator-->>C_API: Return NVTEGroupedTensor (index as void*)
C_API-->>User: Return tensor handle
User->>C_API: nvte_set_grouped_tensor_param(tensor, param, data)
C_API->>Allocator: convertNVTEGroupedTensor(tensor)
Note over Allocator: Race condition risk: reads atomic size<br/>without mutex, accesses memory vector
Allocator-->>C_API: Return GroupedTensor*
C_API->>Memory: Set parameter (data, first_dims, etc.)
C_API-->>User: Parameter set
User->>Validator: CheckInputGroupedTensor(tensor)
Validator->>Validator: Check has_data() or has_columnwise_data()
Validator->>Validator: CheckGroupedScaleInv()
Validator->>Validator: CheckGroupedTensorShapeArrays()
Note over Validator: Validates:<br/>- Shape arrays (first_dims, last_dims, tensor_offsets)<br/>- Logical shape is 2D<br/>- Data size matches logical_shape<br/>- Scale/scale_inv dtypes
Validator-->>User: Validation result
User->>C_API: nvte_destroy_grouped_tensor(tensor)
C_API->>Allocator: Free(tensor)
Allocator->>Memory: memory[index-1].clear()
Allocator->>Allocator: Push index to free_list
C_API-->>User: Tensor destroyed
3 files reviewed, 2 comments
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
3 files reviewed, 1 comment
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
3 files reviewed, 1 comment
void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name,
                                   const NVTEBasicTensor *param);
This works when we're setting basic tensors, but doesn't generalize to other types (bool/float/etc). Consider using a more general API like how we handle NVTEQuantizationConfig:
TransformerEngine/transformer_engine/common/include/transformer_engine/transformer_engine.h
Lines 369 to 371 in f8cb598
void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
                                            NVTEQuantizationConfigAttribute attr, const void *buf,
                                            size_t size_in_bytes);
This is completely general, but also more cumbersome.
Hi, I don't think we can go with a similar API, i.e., using just `void* buf` and `size_in_bytes`, as we need different dtypes for different fields.
Yet.
transformer_engine/common/include/transformer_engine/transformer_engine.h
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
3 files reviewed, no comments
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
3 files reviewed, no comments
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
3 files reviewed, 2 comments
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
3 files reviewed, no comments
/te-ci L0
timmoon10 left a comment
Overall LGTM, but this is hitting PyTorch test failures due to changes in Tensor::has_data/Tensor::has_columnwise_data. We should merge #2330 first.
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
/te-ci L0
3 files reviewed, no comments
NVTEGroupedTensor ret = reinterpret_cast<NVTEGroupedTensor>(index);
free_list.pop_back();
// 1-based indexing - fully reinitialize the tensor to avoid stale data
memory[index - 1].clear();
Why do we clear it again?
timmoon10 left a comment
LGTM
Description
NVTEGroupedTensor class and helpers
Type of change
Checklist: