torch.transpose seems to be mapped to different ops depending on requires_grad #1487

Open
crcrpar opened this issue Nov 27, 2024 · 2 comments
@crcrpar
Collaborator

crcrpar commented Nov 27, 2024


🐛 Bug

I admit I'm not sure whether this is a bug or expected behavior, but torch.transpose is mapped to torch.transpose when requires_grad=False and to torch.permute otherwise.

To Reproduce

Code sample

import torch
import thunder


@thunder.jit
def f(x):
    return x.transpose(0, 1)


for requires_grad in (False, True):
    print("#" * 120)
    x = torch.rand((4, 2), requires_grad=requires_grad)
    f(x)

    print(f"### {requires_grad = }")
    print(f"### thunder.last_traces(f)[-1]\n{thunder.last_traces(f)[-1]}")

output

########################################################################################################################
### requires_grad = False
### thunder.last_traces(f)[-1]
# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x):
  # x: "cpu f32[4, 2]"
  t0 = torch.transpose(x, 0, 1)  # t0: "cpu f32[2, 4]"
    # t0 = ltorch.transpose(x, 0, 1)  # t0: "cpu f32[2, 4]"
      # t0 = prims.transpose(x, (1, 0))  # t0: "cpu f32[2, 4]"
  return t0
########################################################################################################################
### requires_grad = True
### thunder.last_traces(f)[-1]
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x):
  # x: "cpu f32[4, 2]"
  t0 = torch.permute(x, (1, 0))  # t0: "cpu f32[2, 4]"
    # t0 = ltorch.permute(x, (1, 0))  # t0: "cpu f32[2, 4]"
      # t0 = prims.transpose(x, (1, 0))  # t0: "cpu f32[2, 4]"
  return {'output': t0, 'flat_args': [x], 'flat_output': (t0,)}, ((), ())


@mruberry
Collaborator

mruberry commented Nov 27, 2024

I think what's happening here is that there's no grad rule defined for torch.transpose, so the grad transform decomposes it into a prims.transpose call, which then has a grad rule defined here:

def _transpose_prim_grad(a: TensorProxy, permutation: tuple[int, ...]) -> TensorProxy:

And then prims.transpose maps to torch.permute in the Torch executor here:

def _transpose_prim_transform(a: TensorProxy, /, permutation: Sequence[int]) -> TensorLike:

This happens because prims.transpose is consistent with jax.lax.transpose and torch.permute; it's a little confusing that not everyone agreed on what transpose and permute should mean.
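
For reference, the two spellings agree in plain PyTorch (a minimal check; the shape is arbitrary), which is why the remapping is semantically harmless even if the provenance is surprising:

import torch

x = torch.rand(4, 2)

# torch.transpose swaps two dimensions; torch.permute reorders all of them.
# For a 2D tensor these are the same operation, which is what prims.transpose
# (a full permutation, like jax.lax.transpose) relies on.
assert torch.equal(torch.transpose(x, 0, 1), torch.permute(x, (1, 0)))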

One way to address this oddity might be to add a comment with the provenance of the final operation. Another idea, which we've discussed for a while, would be to add more custom grad rules for torch operations, so that, for example, torch.transpose would have its own grad rule that also calls torch.transpose. This might even improve performance, since custom grad rules can be more efficient than autogenerated ones.
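
As a sanity check on what such a rule would need to compute (plain eager PyTorch, not Thunder's grad-rule API): the backward of a transpose is the same transpose applied to the incoming gradient, so a handwritten rule for torch.transpose could stay in transpose terms end to end.

import torch

x = torch.rand(4, 2, requires_grad=True)
y = x.transpose(0, 1)
grad_out = torch.rand_like(y)

# The gradient of transpose(0, 1) is transpose(0, 1) applied to the output gradient.
(grad_x,) = torch.autograd.grad(y, x, grad_out)
assert torch.equal(grad_x, grad_out.transpose(0, 1))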

fyi @IvanYashchuk; maybe @beverlylytle would be interested in identifying speed or memory differences between the autograds that Thunder generates and PyTorch eager's (and torch.compile's)? Maybe instead of identifying the differences it would be more interesting to just add some more autograd formulas. An automated mechanism for identifying speed and memory differences would probably help us find the most important operations to cover, however, and make it easy to talk about the improvement after a custom formula was added. It would also be important to verify that a custom handwritten formula actually has better performance.

Edit: Forgot to mention, if @beverlylytle and @IvanYashchuk decide to prioritize this, then I'm happy to talk about autograd concepts and autograd in Thunder and how we might systematically measure performance.

@IvanYashchuk
Collaborator

We have a simple way to measure speed and memory differences using pytest-benchmark (speed is printed in the terminal output; memory is not printed, but it is saved in the benchmark results file). Here's an example for gelu:

# pytest thunder/benchmarks/targets.py -k "test_litgpt_gelu" --benchmark-group-by='param:config,param:bs'
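
For an operation that isn't already covered in thunder/benchmarks/targets.py, a minimal standalone pytest-benchmark test could look like the sketch below (the file name, shape, and executor labels are made up for illustration):

# test_transpose_bench.py (hypothetical)
import pytest
import torch
import thunder


@pytest.mark.parametrize("executor", ["eager", "thunder"])
def test_transpose(benchmark, executor):
    x = torch.rand(1024, 1024)

    def fn(a):
        return a.transpose(0, 1)

    compiled = fn if executor == "eager" else thunder.jit(fn)
    # pytest-benchmark's `benchmark` fixture repeatedly calls the target and records timings.
    benchmark(compiled, x)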

The challenge is determining the appropriate input shape for comparison. This is easy for elementwise operations but more difficult for operations like scaled dot product attention. When we have concrete shapes to analyze (for example, from logs of ThunderFX recorded in SubgraphInfo), we can create benchmarks for each PyTorch operation from the input graph, similar to the per-graph benchmarking in ThunderCompilerGraphBenchmarking:
This class acts as a backend for the :func:`torch.compile` function, facilitating the benchmarking of each :class:`torch.fx.GraphModule` produced by Thunder dynamo splitter.
Each :class:`torch.fx.GraphModule` instance is executed by the specified executors and benchmarked using `pytest-benchmark`.

Args:
    bench: the BenchmarkFixture created by ``pytest_benchmark``
    executors: A dictionary of functors to compare.
        - Key: The name of the executor to be displayed in the test name.
        - Value: A callable representing the compile function to be applied to the GraphModule.
          If the value is None, no compilation is performed, and the GraphModule runs in Torch eager mode.

Example:
    .. code-block:: python

        # script.py
        import torch
        import thunder
        from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking

        def func(x):
            x = torch.sin(x)
            if x.sum() > 0:
                return x + 1
            else:
                return x - 1

        def test_func(benchmark):
            backend = ThunderCompilerGraphBenchmarking(benchmark, executors={"eager": None, "thunder": thunder.jit})
            compiled = torch.compile(backend=backend)(func)
            x = torch.ones(2, requires_grad=True).cuda()
            compiled(x)

Note:
    Ensure the pytest configuration file (`thunder/tests/conftest.py`) is present in the same directory as `script.py` to provide the grouping customization.
    To run the benchmark test and group the results by split module, execute the following command:
    `pytest script.py --benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'`
    In this example, Dynamo segments the graph into two subgraphs, each identified by the 'GraphID[id]' field in the test name.
    Each subgraph contains a single split module, processed by the Thunder-defined splitter, which corresponds to the 'SplitModuleName[split_module_name]' field.
    The currently active executor is indicated by the 'executor[executor_name]'.
    With `--benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'`, the test cases are grouped based on GraphID and SplitModuleName, allowing for performance comparison between different executors (e.g., 'eager' vs. 'thunder').
