diff --git a/docs/source/reference/transforms/index.rst b/docs/source/reference/transforms/index.rst
index 88eb26bb8c..8711275e14 100644
--- a/docs/source/reference/transforms/index.rst
+++ b/docs/source/reference/transforms/index.rst
@@ -7,3 +7,4 @@ thunder.transforms
     :toctree: generated/
 
     MaterializationTransform
+    ConstantFolding
diff --git a/thunder/transforms/__init__.py b/thunder/transforms/__init__.py
index b1b9ad130d..2ae556c2a0 100644
--- a/thunder/transforms/__init__.py
+++ b/thunder/transforms/__init__.py
@@ -1,8 +1,10 @@
+from .constant_folding import ConstantFolding
 from .materialization import MaterializationTransform
 from .qlora import LORATransform
 
 
 __all__ = [
-    "MaterializationTransform",
+    "ConstantFolding",
     "LORATransform",
+    "MaterializationTransform",
 ]
diff --git a/thunder/transforms/constant_folding.py b/thunder/transforms/constant_folding.py
index 24085793bc..e1994e1fd9 100644
--- a/thunder/transforms/constant_folding.py
+++ b/thunder/transforms/constant_folding.py
@@ -13,6 +13,12 @@ from thunder.torch import _torch_to_thunder_function_map
 from thunder.core.utils import get_symbols_to_last_used_variables
 
+
+__all__ = [
+    "ConstantFolding",
+]
+
+
 _thunder_to_torch_function_map = {v: k for k, v in _torch_to_thunder_function_map.items()}
 
 # Factory functions whose value we know.
@@ -70,6 +76,72 @@ def materialize_args(a):
 
 
 class ConstantFolding(thunder.Transform):
+    """Apply constant folding to the computation trace.
+
+    Once this transform has been applied to a computation trace, subsequent
+    passes (that is, trace transformations) operate on the simplified
+    computation.
+
+    .. code-block:: python
+        :name: example-constant_folding
+
+        from thunder.transforms import ConstantFolding
+
+        model = ...
+        transforms = [ConstantFolding()]
+        jitted = thunder.jit(model, transforms=transforms)
+
+        # If you prefer `ThunderCompiler`...
+        from thunder.dynamo import ThunderCompiler
+        backend = ThunderCompiler(transforms=transforms)
+        jitted = torch.compile(model, backend=backend)
+
+    To see the effect of this transform, let's use the following function:
+
+    .. code-block:: python
+
+        def forward(x):
+            scale_t = torch.tensor([2.])
+            scale_t = (scale_t * 10) / 5
+            return x * scale_t
+
+    The initial computation trace is as follows:
+
+    .. code-block:: python
+
+        def computation(x):
+            # x: "cpu f32[3]"
+
+            scale_t = ltorch.tensor([2.0], device=None, dtype=None, requires_grad=False, pin_memory=False)  # scale_t: "cpu f32[1]"
+              # scale_t = prims.tensor_from_sequence([2.0], dtype=None, device=devices.Device("cpu"))  # scale_t: "cpu f32[1]"
+
+            t1 = ltorch.mul(scale_t, 10)  # t1: "cpu f32[1]"
+              # _ = prims.convert_element_type(10, float)
+              # t1 = prims.mul(scale_t, 10.0)  # t1: "cpu f32[1]"
+            t2 = ltorch.true_divide(t1, 5)  # t2: "cpu f32[1]"
+              # _ = prims.convert_element_type(5, float)
+              # t2 = prims.div(t1, 5.0)  # t2: "cpu f32[1]"
+
+            t4 = ltorch.mul(x, t2)  # t4: "cpu f32[3]"
+              # t3 = prims.broadcast_in_dim(t2, (3,), (0,))  # t3: "cpu f32[3]"
+              # t4 = prims.mul(x, t3)  # t4: "cpu f32[3]"
+            return t4
+
+    This transform simplifies the trace to:
+
+    .. code-block:: python
+
+        def computation(x):
+            # x: "cpu f32[3]"
+            t2 = prims.tensor_from_sequence([4.0], dtype=dtypes.float32, device=devices.Device("cpu"))  # t2: "cpu f32[1]"
+
+            t4 = ltorch.mul(x, t2)  # t4: "cpu f32[3]"
+              # t3 = prims.broadcast_in_dim(t2, (3,), (0,))  # t3: "cpu f32[3]"
+              # t4 = prims.mul(x, t3)  # t4: "cpu f32[3]"
+            return {'output': t4, 'flat_args': [x]}
+
+    """
+
     def transform_traces_pre_prologue(self, prologue_trc, computation_trc, epilogue_trc, **kwargs):
         # Create a new trace
         const_folded_trace = from_trace(computation_trc)
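
As a quick way to check the behavior described in the docstring, here is a minimal usage sketch (not part of the diff above). It reuses the `forward` example from the docstring and assumes `thunder.last_traces` can be used to retrieve the computation traces from the most recent run:

.. code-block:: python

    # Minimal sketch: run the docstring's `forward` example under thunder.jit
    # with ConstantFolding and inspect the resulting trace.
    import torch
    import thunder
    from thunder.transforms import ConstantFolding


    def forward(x):
        scale_t = torch.tensor([2.0])
        scale_t = (scale_t * 10) / 5
        return x * scale_t


    jitted = thunder.jit(forward, transforms=[ConstantFolding()])
    out = jitted(torch.randn(3))

    # The final trace should contain the folded constant tensor ([4.0]) instead
    # of the original tensor construction followed by the mul/div chain.
    print(thunder.last_traces(jitted)[-1])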