[docs] thunder.transforms.ConstantFolding (#1466)
crcrpar authored Dec 2, 2024
1 parent 2baaeb7 commit 44e98ae
Showing 3 changed files with 76 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/reference/transforms/index.rst
@@ -7,3 +7,4 @@ thunder.transforms
    :toctree: generated/

    MaterializationTransform
    ConstantFolding
4 changes: 3 additions & 1 deletion thunder/transforms/__init__.py
@@ -1,8 +1,10 @@
from .constant_folding import ConstantFolding
from .materialization import MaterializationTransform
from .qlora import LORATransform


__all__ = [
    "ConstantFolding",
    "LORATransform",
    "MaterializationTransform",
]
72 changes: 72 additions & 0 deletions thunder/transforms/constant_folding.py
@@ -13,6 +13,12 @@
from thunder.torch import _torch_to_thunder_function_map
from thunder.core.utils import get_symbols_to_last_used_variables


__all__ = [
    "ConstantFolding",
]


_thunder_to_torch_function_map = {v: k for k, v in _torch_to_thunder_function_map.items()}

# Factory functions whose value we know.
@@ -70,6 +76,72 @@ def materialize_args(a):


class ConstantFolding(thunder.Transform):
"""Apply Constant Folding to computation trace.
With this transform applied to a computation trace, successive passes
(meaning trace transformations) can transform the simplified compute.
.. code-block:: python
:name: example-constant_folding
        import torch
        import thunder

        from thunder.transforms import ConstantFolding

        model = ...
        transforms = [ConstantFolding()]
        jitted = thunder.jit(model, transforms=transforms)

        # If you prefer `ThunderCompiler`...
        from thunder.dynamo import ThunderCompiler
        backend = ThunderCompiler(transforms=transforms)
        jitted = torch.compile(model, backend=backend)

    To see the effect of this transform, let's use the following function:

    .. code-block:: python

        def forward(x):
            scale_t = torch.tensor([2.])
            scale_t = (scale_t * 10) / 5
            return x * scale_t

    The initial computation trace is as follows:

    .. code-block:: python

        def computation(x):
          # x: "cpu f32[3]"
          scale_t = ltorch.tensor([2.0], device=None, dtype=None, requires_grad=False, pin_memory=False)  # scale_t: "cpu f32[1]"
            # scale_t = prims.tensor_from_sequence([2.0], dtype=None, device=devices.Device("cpu"))  # scale_t: "cpu f32[1]"
          t1 = ltorch.mul(scale_t, 10)  # t1: "cpu f32[1]"
            # _ = prims.convert_element_type(10, float)
            # t1 = prims.mul(scale_t, 10.0)  # t1: "cpu f32[1]"
          t2 = ltorch.true_divide(t1, 5)  # t2: "cpu f32[1]"
            # _ = prims.convert_element_type(5, float)
            # t2 = prims.div(t1, 5.0)  # t2: "cpu f32[1]"
          t4 = ltorch.mul(x, t2)  # t4: "cpu f32[3]"
            # t3 = prims.broadcast_in_dim(t2, (3,), (0,))  # t3: "cpu f32[3]"
            # t4 = prims.mul(x, t3)  # t4: "cpu f32[3]"
          return t4

    This transform simplifies the trace to:

    .. code-block:: python

        def computation(x):
          # x: "cpu f32[3]"
          t2 = prims.tensor_from_sequence([4.0], dtype=dtypes.float32, device=devices.Device("cpu"))  # t2: "cpu f32[1]"
          t4 = ltorch.mul(x, t2)  # t4: "cpu f32[3]"
            # t3 = prims.broadcast_in_dim(t2, (3,), (0,))  # t3: "cpu f32[3]"
            # t4 = prims.mul(x, t3)  # t4: "cpu f32[3]"
          return {'output': t4, 'flat_args': [x]}
    """

    def transform_traces_pre_prologue(self, prologue_trc, computation_trc, epilogue_trc, **kwargs):
        # Create a new trace
        const_folded_trace = from_trace(computation_trc)
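For a quick end-to-end check, here is a minimal sketch (an illustration, not part of the commit; it assumes thunder and torch are installed and reuses the `forward` function from the docstring example):

import torch
import thunder
from thunder.transforms import ConstantFolding

def forward(x):
    scale_t = torch.tensor([2.0])
    scale_t = (scale_t * 10) / 5
    return x * scale_t

# Jit with the transform and call once so that traces are recorded.
jitted = thunder.jit(forward, transforms=[ConstantFolding()])
out = jitted(torch.randn(3))

# thunder.last_traces returns the chain of computation traces from the
# last call; the final entry shows the constant-folded compute.
print(thunder.last_traces(jitted)[-1])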
