[PyTorch] Add record_stream and untyped_storage func op in QuantizedTensor #2144
Description
In the FP8 dataflow path (dispatch → expert fc1), under 1F1B overlap, when a QuantizedTensor (FP8 payload plus scale_inv metadata) is passed from dispatch to expert fc1 we need to switch to the work stream and safely release memory still tracked by the previous stream. This reduces peak HBM usage and avoids holding memory longer than necessary.
Type of change
Changes
__torch_dispatch__ handler for aten.record_stream on QuantizedTensor:
We record every relevant CUDA buffer held by the quantized tensor — _rowwise_data/_columnwise_data and their _rowwise_scale_inv/_columnwise_scale_inv — onto the provided stream via record_stream(stream). This does not change tensor values; it only updates storage-lifetime metadata so the caching allocator will not reuse or free the memory before the stream finishes its pending asynchronous work.
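The forwarding behavior described above can be sketched with a small, CUDA-free model. The buffer attribute names follow the PR description; the classes below are simplified stand-ins (not the actual QuantizedTensor implementation), and record_stream here merely logs the stream instead of talking to the CUDA allocator:

```python
class FakeBuffer:
    """Stand-in for a CUDA tensor; remembers which streams were recorded."""
    def __init__(self, name):
        self.name = name
        self.recorded_streams = []

    def record_stream(self, stream):
        # A real CUDA tensor marks its storage as "in use by `stream`",
        # so the caching allocator defers reuse until that stream is done.
        self.recorded_streams.append(stream)


class QuantizedTensorSketch:
    """Simplified model of a QuantizedTensor holding four CUDA buffers."""
    def __init__(self):
        self._rowwise_data = FakeBuffer("rowwise_data")
        self._columnwise_data = FakeBuffer("columnwise_data")
        self._rowwise_scale_inv = FakeBuffer("rowwise_scale_inv")
        self._columnwise_scale_inv = FakeBuffer("columnwise_scale_inv")

    def record_stream(self, stream):
        # What the aten.record_stream dispatch handler does in spirit:
        # forward to every live buffer. Values are untouched; only
        # storage-lifetime metadata changes.
        for buf in (self._rowwise_data, self._columnwise_data,
                    self._rowwise_scale_inv, self._columnwise_scale_inv):
            if buf is not None:
                buf.record_stream(stream)


qt = QuantizedTensorSketch()
work_stream = object()  # placeholder for a torch.cuda.Stream
qt.record_stream(work_stream)
assert all(work_stream in b.recorded_streams
           for b in (qt._rowwise_data, qt._columnwise_data,
                     qt._rowwise_scale_inv, qt._columnwise_scale_inv))
```

The key point is that a single record_stream call on the wrapper must reach all four underlying buffers; missing the scale_inv tensors would let the allocator recycle them while the work stream is still reading.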
Expose QuantizedTensor.untyped_storage():
Returns the payload’s underlying UntypedStorage. Callers can then call resize_(0) on it to immediately shrink the storage capacity to zero and return the memory to the caching allocator (on CUDA).
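A minimal CPU-side demonstration of the release pattern this exposes, using the standard PyTorch UntypedStorage API (the FP8 payload is stood in for by a plain float32 tensor; on CUDA, resize_(0) should only follow a record_stream on every stream with pending reads of the buffer):

```python
import torch

# Stand-in for the quantized payload tensor.
t = torch.empty(1024, dtype=torch.float32)

storage = t.untyped_storage()
print(storage.nbytes())  # 4096 — current storage capacity in bytes

# Shrink the storage to zero capacity, returning the memory to the
# allocator. The tensor's metadata (shape, dtype) survives, but its
# data must not be accessed afterwards.
storage.resize_(0)
print(storage.nbytes())  # 0
```

This is the same idiom used elsewhere in PyTorch for activation offloading: keep the tensor object (and its metadata) alive while eagerly handing the backing memory back to the allocator.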
Checklist: