diff --git a/docs/source/reference/transforms/index.rst b/docs/source/reference/transforms/index.rst index 8711275e14..0e12e3a04b 100644 --- a/docs/source/reference/transforms/index.rst +++ b/docs/source/reference/transforms/index.rst @@ -6,5 +6,6 @@ thunder.transforms .. autosummary:: :toctree: generated/ + flatten_tensor_subclasses MaterializationTransform ConstantFolding diff --git a/thunder/transforms/__init__.py b/thunder/transforms/__init__.py index 2ae556c2a0..ee4d72249e 100644 --- a/thunder/transforms/__init__.py +++ b/thunder/transforms/__init__.py @@ -1,9 +1,11 @@ from .constant_folding import ConstantFolding from .materialization import MaterializationTransform from .qlora import LORATransform +from .tensor_subclasses import flatten_tensor_subclasses __all__ = [ + "flatten_tensor_subclasses", "ConstantFolding", "LORATransform", "MaterializationTransform",