
ThunderFX: Interface UX #1529

Closed
kiya00 opened this issue Dec 9, 2024 · 6 comments · Fixed by #1535
Labels: thunderfx (for things that could be applicable to the dynamo+thunder frontend), ux

Comments

kiya00 (Collaborator) commented Dec 9, 2024

It would be nice if instead of having to write

import torch
from thunder.dynamo import ThunderCompiler

backend = ThunderCompiler()
opt_foo1 = torch.compile(foo, backend=backend)

We instead allowed practitioners to write

from thunder import thunderfx

tfoo = thunderfx(foo)

Notebooks like this one are a good opportunity to identify areas where our UX can improve. We can also have the thunderfx function accept **kwargs and pass them through to the torch.compile call.

It'd be great to see a PR changing this so we can update this doc.



Originally posted by @mruberry in #1524 (comment)

cc @Borda

kiya00 added the ux and thunderfx labels Dec 9, 2024
kiya00 (Collaborator, Author) commented Dec 9, 2024

As @mruberry suggested, we can wrap it like this. WDYT, @t-vi @IvanYashchuk?

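import torch
from typing import Callable
from thunder.dynamo import ThunderCompiler
def thunderfx(fn: Callable, thunder_options: dict, **kwargs):
    backend = ThunderCompiler(**thunder_options)  # thunder_options are ThunderCompiler's constructor kwargs
    return torch.compile(fn, backend=backend, **kwargs)  # remaining kwargs go to torch.compile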

thunder_options holds the keyword arguments for the ThunderCompiler constructor.

mruberry (Collaborator) commented Dec 9, 2024

Note from triage today: we should follow up by thinking about how we explain the thunderfx vs. thunder.jit entry points.

I filed #1532 to track that issue, which we can separate from these proposed changes.

mruberry (Collaborator) commented Dec 9, 2024

Replying to @kiya00's proposed thunderfx wrapper above:

This is a good idea. What about the following tweak?

import inspect
from typing import Callable

import torch
from thunder.dynamo import ThunderCompiler

def thunderfx(fn: Callable, /, **kwargs):
    # Separate torch.compile kwargs from thunder kwargs here.
    # To separate them, we could do something like this:
    torch_compile_kwarg_names = inspect.getfullargspec(torch.compile).kwonlyargs

    torch_compile_options = {k: v for k, v in kwargs.items() if k in torch_compile_kwarg_names}
    thunder_options = {k: v for k, v in kwargs.items() if k not in torch_compile_kwarg_names}

    # Change ThunderCompiler to take thunder_options and torch_compile_options separately and as dicts
    backend = ThunderCompiler(thunder_options, torch_compile_options)
    return torch.compile(fn, backend=backend, **torch_compile_options)

That would let all arguments be specified as regular kwargs, and I think it would slightly clean up the ThunderCompiler __init__ method, too. This approach might have some issues in the future, because:

  1. torch.compile could start accepting arbitrary **kwargs, so we'd have to send the extra keywords to both torch.compile and thunder and hope they don't overlap
  2. torch.compile could add new keyword arguments that overlap with thunder keyword arguments, so we'd have to detect if an argument applies to both and send it to both as appropriate or change the UX again

I think these are OK problems to have, and we can address them if they come up in the future.
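A minimal, torch-free sketch of the splitting step and of the two problems listed above. fake_torch_compile is a hypothetical stand-in whose keyword-only parameters mirror those of torch.compile, split_kwargs is a hypothetical helper name, and "executors" is used only as an example thunder option name. It swaps inspect.getfullargspec for inspect.signature, and it refuses to guess when the compile function accepts **kwargs (problem 1) or when a name is known to both sides (problem 2).

```python
import inspect
from typing import Any, Callable


def fake_torch_compile(model=None, *, fullgraph=False, dynamic=None, backend="inductor",
                       mode=None, options=None, disable=False):
    """Hypothetical stand-in whose keyword-only parameters mirror torch.compile's."""
    return model


def split_kwargs(
    compile_fn: Callable,
    kwargs: dict[str, Any],
    thunder_option_names: frozenset = frozenset(),
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Hypothetical helper: split kwargs into (compile_options, thunder_options) by name."""
    params = list(inspect.signature(compile_fn).parameters.values())
    # Problem 1: if the compile function accepts **kwargs, routing by name becomes ambiguous.
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
        raise TypeError("compile function accepts **kwargs; cannot split options by name")
    compile_names = {p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY}
    # Problem 2: a name understood by both sides would need to reach both, so refuse to guess.
    overlap = compile_names & set(thunder_option_names) & set(kwargs)
    if overlap:
        raise ValueError(f"options accepted by both torch.compile and thunder: {sorted(overlap)}")
    compile_options = {k: v for k, v in kwargs.items() if k in compile_names}
    thunder_options = {k: v for k, v in kwargs.items() if k not in compile_names}
    return compile_options, thunder_options


compile_opts, thunder_opts = split_kwargs(fake_torch_compile, {"fullgraph": True, "executors": ["nvfuser"]})
print(compile_opts)  # {'fullgraph': True}
print(thunder_opts)  # {'executors': ['nvfuser']}
```

One caveat this surfaces: backend is itself a keyword-only parameter of torch.compile, so a wrapper that also passes backend= explicitly would need to drop it from the compile options (or reject it) to avoid a duplicate-keyword TypeError.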

One refinement of this scheme would be to list known kwargs explicitly in the signature, so the signature could include, for example, "fullgraph" and the thunder options. That would make the signature more readable for anyone who looks at it. The ThunderCompiler __init__ method could then also accept some kwargs explicitly, which might be clearer.
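A sketch of that refinement, assuming fullgraph and dynamic are the only torch.compile options worth naming in the signature. FakeThunderCompiler, fake_torch_compile, and thunderfx_sketch are hypothetical stand-ins, not thunder or PyTorch APIs, so the example runs without either library.

```python
from typing import Any, Callable, Optional


class FakeThunderCompiler:
    """Hypothetical stand-in for ThunderCompiler that only records its options."""

    def __init__(self, **thunder_options: Any) -> None:
        self.thunder_options = thunder_options


def fake_torch_compile(fn: Callable, *, backend: Any, fullgraph: bool = False,
                       dynamic: Optional[bool] = None) -> Callable:
    """Hypothetical stand-in for torch.compile: attaches the settings it received and returns fn."""
    fn.compile_settings = {"backend": backend, "fullgraph": fullgraph, "dynamic": dynamic}
    return fn


def thunderfx_sketch(
    fn: Callable,
    /,
    *,
    fullgraph: bool = False,
    dynamic: Optional[bool] = None,
    **thunder_options: Any,
) -> Callable:
    # Known torch.compile options are spelled out, so they appear in help() and IDE tooltips;
    # every other keyword is treated as a thunder option and forwarded to the backend.
    backend = FakeThunderCompiler(**thunder_options)
    return fake_torch_compile(fn, backend=backend, fullgraph=fullgraph, dynamic=dynamic)


def foo(x):
    return x + 1


tfoo = thunderfx_sketch(foo, fullgraph=True, executors=["nvfuser"])
print(tfoo(1))                                           # 2
print(tfoo.compile_settings["backend"].thunder_options)  # {'executors': ['nvfuser']}
```

A trade-off of this shape: a misspelled compile option (say, fulgraph) silently lands in **thunder_options instead of raising, so the backend may want to validate the names it receives.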

kshitij12345 (Collaborator)

One thing to note is that the ThunderCompiler object itself holds the Dynamo subgraphs, along with other details useful for debugging, in SubgraphInfo. So anyone who wants to debug, inspect the splits, or grab the traces of the thunder function(s) for a subgraph currently needs access to the ThunderCompiler object.

So the UX should consider this object as well.

A few options I can think of:

  1. Return a future-like object that will be populated with the subgraph info details, so the signature would become something like thunderfx(fn, /, **kwargs) -> tuple[Callable, SubgraphInfoFutureObject]. This object is not ready until the first call to the returned Callable.
  2. Similar to what we have for grabbing traces (thunder.last_traces), we can attach the SubgraphInfo to the Callable returned from the thunderfx function and then have a helper function get_subgraph_info that pulls this data if available.
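A minimal sketch of option 1: the returned handle stays empty until the first call populates it. SubgraphInfoFuture and thunderfx_with_future are hypothetical names, a plain string stands in for each real SubgraphInfo, and the sketch records only the first call (a real Dynamo backend can be invoked again on recompilation and add more subgraphs).

```python
from typing import Any, Callable, Optional


class SubgraphInfoFuture:
    """Hypothetical handle that is empty until the compiled callable runs for the first time."""

    def __init__(self) -> None:
        self._infos: Optional[list[Any]] = None

    def done(self) -> bool:
        return self._infos is not None

    def result(self) -> list[Any]:
        if self._infos is None:
            raise RuntimeError("subgraph info is not available until the first call")
        return self._infos

    def _set(self, infos: list[Any]) -> None:
        self._infos = infos


def thunderfx_with_future(fn: Callable) -> tuple[Callable, SubgraphInfoFuture]:
    future = SubgraphInfoFuture()

    def compiled(*args, **kwargs):
        # Stand-in for Dynamo tracing plus the Thunder backend: record "subgraph info" on the first call.
        if not future.done():
            future._set([f"subgraph of {fn.__name__}"])
        return fn(*args, **kwargs)

    return compiled, future


def foo(x):
    return x * 2


tfoo, info = thunderfx_with_future(foo)
print(info.done())    # False
print(tfoo(3))        # 6
print(info.result())  # ['subgraph of foo']
```

The cost of this design is the extra return value, which makes thunderfx's return type differ from torch.compile's; option 2 avoids that by hanging the data off the callable instead.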

Irrespective of the approach, I think my main point is to make sure that this information is still easily accessible.

import dataclasses


@dataclasses.dataclass(frozen=True)
class SubgraphInfo:
    """A dataclass containing information about a subgraph.
    Attributes:
        original_graph_module: The original graph module.
        original_split_graph_module: The original split graph module before any transformations are applied.
            Specifically, before the :func:`checkpoint_converter` replaces the Torch operators with Thunder symbols,
            and before any submodules are compiled by Thunder.
        split_graph_module: The graph module for the split subgraph. It contains the compiled thunder/inductor modules.
        thunder_compiled_fns: List of thunder optimized callables.
            This could be :obj:`None` if the graph module was not supported by thunder.
            Look at the :attr:`split_reasons` for further information.
        thunder_compiled_fns_example_inputs: List containing metadata of sample inputs for `thunder_compiled_fns`.
            These inputs are used to generate random test inputs in the reproducer script.
        submodule_to_compiled_functions: Dict from subgraph in :attr:`original_split_graph_module` to compiled function.
            This will be a dict with one pair in case the graph was not split.
        split_reasons: List of reasons explaining why the subgraph was split.
            Present only if there was a split.
    """

kiya00 added a commit that referenced this issue Dec 10, 2024
kiya00 added a commit that referenced this issue Dec 10, 2024
kiya00 mentioned this issue Dec 10, 2024
kiya00 (Collaborator, Author) commented Dec 10, 2024

Yes, so far ThunderCompiler provides debugging information through its SubgraphInfo attribute and its save_reproducer_to_folder method. In #1535 I tried simply adding a _backend attribute to the returned compiled callable so that advanced users have full access to the ThunderCompiler. Or do we want a set of dedicated interfaces instead, like get_subgraph_infos and save_reproducer?
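A sketch of how the two ideas can coexist: the backend object is attached to the returned callable as _backend, and a small accessor in the spirit of get_subgraph_infos reads from it. FakeThunderCompiler, its subgraph_infos list, thunderfx_sketch, and get_subgraph_infos are hypothetical stand-ins; the real attribute and helper names may differ.

```python
import functools
from typing import Any, Callable, Optional


class FakeThunderCompiler:
    """Hypothetical stand-in for ThunderCompiler; appends one record per (fake) compiled subgraph."""

    def __init__(self, **thunder_options: Any) -> None:
        self.thunder_options = thunder_options
        self.subgraph_infos: list[str] = []

    def __call__(self, fn: Callable) -> Callable:
        self.subgraph_infos.append(f"subgraph of {fn.__name__}")
        return fn


def thunderfx_sketch(fn: Callable, /, **thunder_options: Any) -> Callable:
    backend = FakeThunderCompiler(**thunder_options)
    compiled: Optional[Callable] = None

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal compiled
        if compiled is None:  # compile lazily on the first call, as torch.compile does
            compiled = backend(fn)
        return compiled(*args, **kwargs)

    # Expose the backend object for advanced users, mirroring the _backend idea.
    wrapper._backend = backend
    return wrapper


def get_subgraph_infos(fn: Callable) -> list[str]:
    """Hypothetical accessor, similar in style to thunder.last_traces(fn)."""
    backend = getattr(fn, "_backend", None)
    if backend is None:
        raise TypeError(f"{fn!r} was not produced by thunderfx_sketch")
    return backend.subgraph_infos


def foo(x):
    return x - 1


tfoo = thunderfx_sketch(foo)
print(get_subgraph_infos(tfoo))  # []
print(tfoo(5))                   # 4
print(get_subgraph_infos(tfoo))  # ['subgraph of foo']
```

An accessor like this lets _backend stay private while giving users a stable entry point, so the two approaches are complementary rather than competing.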

I also noticed that in some cases users call it like this:

import torch
import thunder.dynamo
from transformers import AutoModelForCausalLM

# `config` (a transformers model config) and `xforms` (a list of thunder transforms) are defined elsewhere
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
be = thunder.dynamo.ThunderCompiler(transforms=xforms)
model.compile(backend=be)

In this case, it seems easier to use ThunderCompiler directly.

mruberry (Collaborator)

Replying to @kiya00 above:

Adding attributes to the returned callable and creating new methods to access this information both sound like great ideas that we can follow up on in the future.

It would also be nice to consider how those functions compare to the existing functions for accessing thunder.jit data, like:

def last_traces(fn) -> list[TraceCtx]:

But we can also work on aligning these options later.

3 participants