
feat: add fullgraph argument for torch compile #1007

Merged 1 commit into Lightning-AI:main on Aug 22, 2024

Conversation

@k223kim (Contributor) commented Aug 21, 2024

Before submitting
  • Was this discussed/approved via a GitHub issue? (not needed for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #281. As this is a fairly old issue, I am not sure whether it is still needed. Also, I am not sure how to add a test case for this.
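For reference, a minimal usage sketch of the option this PR adds (the torch_compile_fullgraph name is taken from the invocation shown later in this thread; treat the exact API as illustrative):

import torch
import thunder
from thunder.executors.torch_compile import torch_compile_ex

def fn(a):
    return a * 2 + 1

# torch_compile_fullgraph is forwarded to torch.compile's fullgraph argument.
jfn = thunder.jit(fn, executors=[torch_compile_ex], torch_compile_fullgraph=True)
jfn(torch.randn(2, 2))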

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@t-vi (Collaborator) left a comment

Looks good. Would we want to add a test, say running a llama2-like model with this, to test_networks or so?

@k223kim (Contributor, Author) commented Aug 21, 2024

I could; however, I don't think that will allow us to check whether fullgraph has been modified. What do you think? Is there a way to check which parameters have been passed to torch.compile that I am not aware of?
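One generic way to observe this, as a hypothetical sketch not taken from the thread (and it only helps if thunder resolves torch.compile through the torch module attribute at compile time, which the thread does not confirm), is to patch torch.compile with a spy:

from unittest import mock
import torch

def fn(a):
    return a * 2 + 1

# Wrap torch.compile so every call is recorded while still delegating
# to the real implementation.
with mock.patch("torch.compile", wraps=torch.compile) as spy:
    cfn = torch.compile(fn, fullgraph=True)
    cfn(torch.randn(2, 2))
    assert spy.call_args.kwargs["fullgraph"] is True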

@t-vi (Collaborator) commented Aug 22, 2024

Right, I had hoped we could get it from the compiler config or so, but apparently it is None...

import torch
import thunder
from thunder.executors.torch_compile import torch_compile_ex

def fn(a):
    return a * 2 + 1

jfn = thunder.jit(fn, executors=[torch_compile_ex])
jfn(torch.randn(2, 2))

# Pull the TorchCompile0 region out of the final trace and ask it for
# its compiler config -- it comes back as None.
lt = thunder.last_traces(jfn)[-1]
tc = lt.python_ctx()['TorchCompile0']
cc = tc.__closure__[0].cell_contents.get_compiler_config()

So let's keep it as is.

@kshitij12345 (Collaborator)

I think a simple way to test fullgraph=True would be to create a function with a graph break. Compiling it should then throw an exception, which we can verify with pytest.raises; see the test sketch after the snippet below.

import torch

def foo():
    print("This will lead to graph break")

torch.compile(foo, fullgraph=True)()  # This will throw `torch._dynamo.exc.Unsupported`
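A minimal sketch of what that pytest.raises check could look like (a hypothetical test, not part of the PR):

import pytest
import torch
from torch._dynamo.exc import Unsupported

def foo():
    print("This will lead to graph break")

def test_fullgraph_raises():
    # print() forces a graph break, so compiling with fullgraph=True raises.
    with pytest.raises(Unsupported):
        torch.compile(foo, fullgraph=True)()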

@k223kim (Contributor, Author) commented Aug 22, 2024

@kshitij12345 Hey! Thanks for the suggestion. I did try this, but I don't think it will throw an error within thunder:

import torch
import thunder
from thunder.executors.torch_compile import torch_compile_ex

def foo():
    print("This will lead to graph break")

torch.compile(foo, fullgraph=True)()  # this throws torch._dynamo.exc.Unsupported
jfn = thunder.jit(foo, executors=[torch_compile_ex], torch_compile_fullgraph=True, torch_compile_backend="reduce-overhead")
jfn()  # this does not throw an error

This makes it difficult to see which parameters have been passed to torch.compile. Do you have a different suggestion?

@k223kim k223kim marked this pull request as ready for review August 22, 2024 15:18
@kshitij12345 (Collaborator) commented Aug 22, 2024

This makes it difficult to see which parameters have been passed to torch.compile. Do you have a different suggestion?

You are right; I forgot that torch.compile is called on the trace from thunder, so it receives only the parts the executor claims, which lets it compile the full graph it is given without breaks.
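A quick way to see this (a sketch reusing the executor example from earlier in the thread):

import torch
import thunder
from thunder.executors.torch_compile import torch_compile_ex

def fn(a):
    return a * 2 + 1

jfn = thunder.jit(fn, executors=[torch_compile_ex], torch_compile_fullgraph=True)
jfn(torch.randn(2, 2))

# The final trace shows the TorchCompile0 region torch.compile actually
# receives: only claimed tensor ops, so nothing in it can cause a graph break.
print(thunder.last_traces(jfn)[-1])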

Thanks for trying this!

@mruberry (Collaborator) left a comment

Seems like a cool option

@mruberry mruberry enabled auto-merge (squash) August 22, 2024 17:09
@mruberry mruberry merged commit 6e1cb40 into Lightning-AI:main Aug 22, 2024
40 checks passed
@github-actions github-actions bot deleted the k223kim/torch_compile_arguments branch November 21, 2024 01:20

Successfully merging this pull request may close these issues:

Expose torch.compile arguments as compile options (#281)