Add splitter to dynamo backend #5

Closed
wants to merge 9 commits
49 changes: 26 additions & 23 deletions thunder/dynamo/compiler.py

Let's create a new file splitter.py.

Owner Author

Have created a new file utils.py as well: it holds the dataclasses that are returned (which are not directly related to the splitter) along with the splitter-related helper functions.
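
For readers following along, here is a rough sketch of what those returned dataclasses might look like, inferred purely from how they are used in this diff; the field and member names are assumptions, not the actual definitions in utils.py:

# Hypothetical sketch of the dataclasses in thunder/dynamo/utils.py, inferred from
# their usage in this PR; the real definitions may differ.
from dataclasses import dataclass, field
from enum import Enum, auto
from collections.abc import Callable

import torch


class CompilerType(Enum):
    # Which compiler handled a given submodule.
    THUNDER = auto()
    TORCH_INDUCTOR = auto()


class SplitReasonType(Enum):
    # Only the member used in this diff is shown; the real enum may have more.
    UNSUPPORTED_NODE = auto()


@dataclass
class CompiledFunction:
    compiled_fn: Callable
    compiler: CompilerType


@dataclass
class SplitReason:
    reason_type: SplitReasonType
    info: str | None = None


@dataclass
class SubgraphInfo:
    # Original Dynamo graph, the split graph, the thunder-compiled callables,
    # the submodule-to-compiled-function mapping, and the reasons for splitting.
    original_graph_module: torch.fx.GraphModule
    split_graph_module: torch.fx.GraphModule
    thunder_compiled_fns: list[Callable]
    submodule_to_compiled_functions: dict[torch.fx.GraphModule, CompiledFunction]
    split_reasons: list[SplitReason] = field(default_factory=list)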

@@ -1,8 +1,17 @@
from typing import List, Dict, Optional, Tuple, Set
from collections.abc import Callable
from functools import partial

import torch
from torch.fx.passes.split_module import split_module
import warnings
from collections.abc import Mapping

from thunder.core.baseutils import run_once

from thunder.dynamo.utils import SubgraphInfo
from thunder.dynamo.splitter import _splitter


@run_once
def _warn_thunder_compiler():
@@ -20,7 +29,9 @@ def __init__(self, **thunder_options):
function.

Keyword arguments:
thunder_options: a dictionary of options to pass to `thunder.jit`.
thunder_options: a dictionary of options to pass to `thunder.jit`. Besides all the arguments to `thunder.jit`,
it accepts `torch_inductor_options` which are passed to `torch.compile` if part of the graph
is not supported by thunder.

Example:
>>> import torch
@@ -36,38 +47,30 @@ def __init__(self, **thunder_options):
... return x - 1
>>> out = func(x)
"""
from thunder import ThunderModule
from thunder import ThunderModule, jit

_warn_thunder_compiler()

# Thunder-compiled functions should be readily available for inspection
# and testing, so we will store them in a list. The order of the
# and testing, so we will store them in a list[SubgraphInfo]. The order of the
# functions in the list will be the same as the order in which they were
# compiled. In addition, we will store a mapping from the ThunderModule
# to the GraphModule that was passed to ThunderCompiler. This will allow
# us to inspect the GraphModule that was compiled by Thunder.
self.thunder_fns: list[ThunderModule] = []
self.thunder_to_gm: dict[ThunderModule, torch.fx.GraphModule] = {}
# compiled.
# Refer to the documentation of `SubgraphInfo` for more about the information it contains.
self.subgraph_infos: list[SubgraphInfo] = []

self.thunder_options = thunder_options
torch_inductor_options = thunder_options.pop("torch_inductor_options", {})

# TODO: There will be pieces of Dynamo IR that Thunder cannot compile, so we
# will need to build a fallback mechanism to handle those cases.
# Possible stages of the compilation that need to be saved for inspection:
# 1. The GraphModule as it was passed to ThunderCompiler.
# 2. The GraphModule after split for Thunder/PyTorch.
# 3. If the whole GraphModule is not supported, record the reasons why.
self.thunder_options = thunder_options
self._thunder_jit = partial(jit, **thunder_options)
self._torch_compile = partial(torch.compile, **torch_inductor_options)

def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
from thunder import jit

# Dynamo uses lazy generation of the underlying Python code, so we need to
# force recompilation of the GraphModule before passing it to Thunder.
gm.real_recompile()

# Here in the future we could add some logic to check if the GraphModule
# is executable by Thunder, but for now we simply compile it and return
jitted_gm = jit(gm, **self.thunder_options)
self.thunder_fns.append(jitted_gm)
self.thunder_to_gm[jitted_gm] = gm
return jitted_gm
# The whole graph may not be supported by `thunder`, so we split it into `thunder`-supported sections
# and unsupported sections, which are passed to `torch.compile(backend='inductor')`.
split_module, subgraph_info = _splitter(gm, self._thunder_jit, self._torch_compile, sample_args)
self.subgraph_infos.append(subgraph_info)
return split_module
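
To make the intended workflow concrete, a usage sketch of the new backend follows. It assumes the `thunder.dynamo.ThunderCompiler` import path from the docstring above, plus the `subgraph_infos` attribute and the SubgraphInfo/SplitReason field names sketched earlier; treat it as an illustration rather than the exact public API.

import torch
from thunder.dynamo import ThunderCompiler

# Keys under `torch_inductor_options` are forwarded to `torch.compile` for the
# unsupported partitions; all other options go to `thunder.jit`.
backend = ThunderCompiler(torch_inductor_options={"mode": "max-autotune"})

@torch.compile(backend=backend)
def func(x):
    y = torch.sinc(x)          # unsupported by thunder (per the example below) -> inductor
    return torch.matmul(x, y)  # supported by thunder

out = func(torch.randn(2))

# One SubgraphInfo is recorded per Dynamo graph passed to the backend.
for info in backend.subgraph_infos:
    print(info.split_graph_module)    # split module with thunder_*/inductor_* submodules
    for reason in info.split_reasons:  # why (and where) the graph was split
        print(reason.info)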
173 changes: 173 additions & 0 deletions thunder/dynamo/splitter.py
@@ -0,0 +1,173 @@
from typing import List, Dict, Optional, Tuple, Set
from collections.abc import Callable
from functools import partial

import torch
from torch.fx.passes.split_module import split_module
import warnings
from collections.abc import Mapping

from thunder.core.baseutils import run_once

from thunder.dynamo.utils import (
SubgraphInfo,
CompiledFunction,
CompilerType,
SplitReason,
SplitReasonType,
is_node_supported_by_thunder,
get_nodes_in_unsupported_ctx_regions,
update_node_and_submodule,
)


def _splitter(
gm: torch.fx.GraphModule,
thunder_jit: Callable,
torch_inductor: Callable,
_unused_sample_args: list[torch.SymInt, torch.Tensor],
) -> tuple[torch.fx.GraphModule, SubgraphInfo]:
"""
This function splits the graph into multiple graph modules based on the operations supported by thunder,
trying to keep each partition contiguous.

Example:
# All operations are supported by thunder
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[2]"):
l_x_ = L_x_

y: "f32[2]" = torch.sin(l_x_)
matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
return (matmul,)

# Split Graph: All operations are supported by thunder, so we will see only one partition.
class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[2]"):
thunder_1 = self.thunder_1(l_x_); l_x_ = None
return (thunder_1,)

class thunder_1(torch.nn.Module):
def forward(self, l_x_: "f32[2]"):
y: "f32[2]" = torch.sin(l_x_)
matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
return matmul

Example:
# With unsupported operation `sinc`
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[2]"):
l_x_ = L_x_

y: "f32[2]" = torch.sinc(l_x_)

matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
return (matmul,)

# Split Graph: Since `sinc` is unsupported, we will see two partitions, one for thunder and one for inductor.
class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[2]"):
inductor_1 = self.inductor_1(l_x_)
thunder_2 = self.thunder_2(l_x_, inductor_1); l_x_ = inductor_1 = None
return (thunder_2,)

class inductor_1(torch.nn.Module): # Partition for inductor
def forward(self, l_x_: "f32[2]"):
y: "f32[2]" = torch.sinc(l_x_); l_x_ = None
return y

class thunder_2(torch.nn.Module): # Partition for thunder
def forward(self, l_x_: "f32[2]", y: "f32[2]"):
matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
return matmul
"""
# The callback below is called for every node in the graph.
# It returns an `int` denoting the partition in which the node should be placed.
# We want to split the graph into contiguous regions (each with one or more operations)
# that are either supported or unsupported by thunder.
# `prev_value` tracks whether we are still in the same region (i.e. supported or unsupported).
# `partition_cnt` is bumped every time the region changes, i.e. we flip from supported to
# unsupported or vice versa.
# `supported_partitions` tracks the thunder-supported partitions.
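# For example, for a graph with nodes [sin (supported), sinc (unsupported), matmul (supported)],
# the callback below returns partition ids 1, 2 and 3 respectively, and `supported_partitions`
# ends up as {1, 3}.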
prev_value = None
partition_cnt = 0
supported_partitions: set[int] = set()
split_reasons: list[SplitReason] = []

nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)

def callback(node) -> int:
nonlocal prev_value, partition_cnt, split_reasons, supported_partitions

assert node.op not in (
"placeholder",
"get_attr",
"output",
), f"fx.split_module should have only passed node.op=call_* but received {node.op}"

if node in nodes_in_unsupported_ctx_regions:
# If the node was in an unsupported ctx region like `autocast`,
# we pass it to `torch.compile` even though the operation may be supported,
# because `thunder` doesn't work correctly with these regions.
is_thunder_supported = False
split_reason = SplitReason(
SplitReasonType.UNSUPPORTED_NODE,
info=f"node with name: {node.name} and target: {node.target} is not supported, likely because it is in an unsupported context.",
)
split_reasons.append(split_reason)
else:
is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
if split_reason is not None:
split_reasons.append(split_reason)

if prev_value == is_thunder_supported: # We are in the same region.
return partition_cnt

# There is a flip. Either from supported to unsupported or unsupported to supported.
prev_value = is_thunder_supported
partition_cnt += 1 # Bump the region cnt.

if is_thunder_supported:
supported_partitions.add(partition_cnt)
return partition_cnt

# `split_module` iterates over the nodes and, based on the callback, determines the partition each node is placed in.
split_gm: torch.fx.GraphModule = split_module(
gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
)

def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions
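# For example, a node named "submod_3" refers to partition 3, which is a thunder partition
# only if 3 was recorded in `supported_partitions` by the callback above.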

# Call compile on each of the split regions.
thunder_compiled_fns = []
submodule_to_compiled_fns = {}
is_split = False
for node in split_gm.graph.nodes:
if is_thunder_supported_partition(node):
graph_module = getattr(split_gm, node.name)
jit_fn = thunder_jit(graph_module)
# Update the node name from "submod_*" to "thunder_*" for more user-friendly names
update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn)
thunder_compiled_fns.append(jit_fn)
submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.THUNDER)
elif node.name.startswith("submod"): # For inductor
graph_module = getattr(split_gm, node.name)
jit_fn = torch_inductor(graph_module)
# Update the node name from "submod_*" to "inductor_*" for more user-friendly names
update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.TORCH_INDUCTOR)
is_split = True
else:
# Everything else is glue code that calls the partitions and passes outputs between them.
pass

# We update the GraphModule in `update_node_and_submodule`, so we need to recompile.
split_gm.recompile()

return split_gm, SubgraphInfo(
gm,
split_gm,
thunder_compiled_fns,
submodule_to_compiled_fns,
split_reasons,
)
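
For reference, here is a standalone sketch of the `torch.fx.passes.split_module.split_module` mechanism that `_splitter` builds on: a tiny traced module is partitioned with a callback. The module and the name-to-partition mapping are illustrative only, not part of this PR.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module


class TinyModule(torch.nn.Module):
    def forward(self, x):
        y = torch.sin(x)   # pretend: supported
        z = torch.sinc(x)  # pretend: unsupported
        return y + z


gm = symbolic_trace(TinyModule())

# Assign a partition id per node name; split_module groups nodes that share an id
# into one "submod_<id>" submodule and generates glue code that calls them in order.
partition_by_name = {"sin": 1, "sinc": 2, "add": 3}

def callback(node: torch.fx.Node) -> int:
    return partition_by_name.get(node.name, 3)

split_gm = split_module(gm, None, callback, keep_original_order=True)
print(split_gm.graph)  # glue graph calling submod_1, submod_2 and submod_3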