forked from Lightning-AI/lightning-thunder
Add splitter to dynamo backend #5
Closed
Commits (9)
- 9f3fc95 add splitter to dynamo backend (kshitij12345)
- ea3de2b failure with dynamic=True (kshitij12345)
- a5158d0 address review: part 1 (kshitij12345)
- f22a921 remove experimental code (kshitij12345)
- 31bbc30 address review: part 2 (kshitij12345)
- 734a3be update code and add comments (kshitij12345)
- e28ec46 add comment (kshitij12345)
- fbe6737 test for submodule name (kshitij12345)
- ecde8d5 address review (kshitij12345)
@@ -0,0 +1,173 @@
from typing import List, Dict, Optional, Tuple, Set
from collections.abc import Callable
from functools import partial

import torch
from torch.fx.passes.split_module import split_module
import warnings
from collections.abc import Mapping

from thunder.core.baseutils import run_once

from thunder.dynamo.utils import (
    SubgraphInfo,
    CompiledFunction,
    CompilerType,
    SplitReason,
    SplitReasonType,
    is_node_supported_by_thunder,
    get_nodes_in_unsupported_ctx_regions,
    update_node_and_submodule,
)


def _splitter(
    gm: torch.fx.GraphModule,
    thunder_jit: Callable,
    torch_inductor: Callable,
    _unused_sample_args: list[torch.SymInt, torch.Tensor],
) -> torch.fx.GraphModule:
    """
    This function splits the graph into multiple graph modules based on the operations supported by thunder.
    It tries to split the graph into contiguous partitions.

    Example:
        # All operations are supported by thunder
        class GraphModule(torch.nn.Module):
            def forward(self, L_x_: "f32[2]"):
                l_x_ = L_x_

                y: "f32[2]" = torch.sin(l_x_)
                matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
                return (matmul,)

        # Split Graph: All operations are supported by thunder, so we will see only one partition.
        class GraphModule(torch.nn.Module):
            def forward(self, l_x_: "f32[2]"):
                thunder_1 = self.thunder_1(l_x_); l_x_ = None
                return (thunder_1,)

            class thunder_1(torch.nn.Module):
                def forward(self, l_x_: "f32[2]"):
                    y: "f32[2]" = torch.sin(l_x_)
                    matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
                    return matmul

    Example:
        # With the unsupported operation `sinc`
        class GraphModule(torch.nn.Module):
            def forward(self, L_x_: "f32[2]"):
                l_x_ = L_x_

                y: "f32[2]" = torch.sinc(l_x_)

                matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
                return (matmul,)

        # Split Graph: Since `sinc` is unsupported, we will see two partitions, one for thunder and one for inductor.
        class GraphModule(torch.nn.Module):
            def forward(self, l_x_: "f32[2]"):
                inductor_1 = self.inductor_1(l_x_)
                thunder_2 = self.thunder_2(l_x_, inductor_1); l_x_ = inductor_1 = None
                return (thunder_2,)

            class inductor_1(torch.nn.Module):  # Partition for inductor
                def forward(self, l_x_: "f32[2]"):
                    y: "f32[2]" = torch.sinc(l_x_); l_x_ = None
                    return y

            class thunder_2(torch.nn.Module):  # Partition for thunder
                def forward(self, l_x_: "f32[2]", y: "f32[2]"):
                    matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
                    return matmul
    """
    # The callback below is called for every node in the graph.
    # It returns an `int` denoting the partition where the node should be placed.
    # We want to partition the graph into contiguous regions (with one or more operations),
    # each of which is either supported or unsupported by thunder.
    # `prev_value` is used to determine if we are still in the same region (i.e. supported or unsupported).
    # `partition_cnt` is bumped every time we change regions, i.e. flip from supported to unsupported or vice versa.
    # `supported_partitions` is used to track the thunder-supported partitions.
    prev_value = None
    partition_cnt = 0
    supported_partitions: set[int] = set()
    split_reasons: list[SplitReason] = []

    nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)

    def callback(node) -> int:
        nonlocal prev_value, partition_cnt, split_reasons, supported_partitions

        assert node.op not in (
            "placeholder",
            "get_attr",
            "output",
        ), f"fx.split_module should have only passed node.op=call_* but received {node.op}"

        if node in nodes_in_unsupported_ctx_regions:
            # If the node is in an unsupported ctx region like `autocast`,
            # we pass it to `torch.compile` even though the operation itself may be supported,
            # as `thunder` doesn't work correctly with these regions.
            is_thunder_supported = False
            split_reason = SplitReason(
                SplitReasonType.UNSUPPORTED_NODE,
                info=f"node with name: {node.name} and target: {node.target} is not supported, probably because it is in an unsupported context.",
            )
            split_reasons.append(split_reason)
        else:
            is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
            if split_reason is not None:
                split_reasons.append(split_reason)

        if prev_value == is_thunder_supported:  # We are in the same region.
            return partition_cnt

        # There is a flip, either from supported to unsupported or from unsupported to supported.
        prev_value = is_thunder_supported
        partition_cnt += 1  # Bump the region cnt.

        if is_thunder_supported:
            supported_partitions.add(partition_cnt)
        return partition_cnt

    # `split_module` iterates over nodes and determines the partition to place them in based on the callback.
    split_gm: torch.fx.GraphModule = split_module(
        gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
    )

    def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
        return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions

    # Call compile on the split region/s.
    thunder_compiled_fns = []
    submodule_to_compiled_fns = {}
    is_split = False
    for node in split_gm.graph.nodes:
        if is_thunder_supported_partition(node):
            graph_module = getattr(split_gm, node.name)
            jit_fn = thunder_jit(graph_module)
            # Update the node name from "submod_*" to "thunder_*" for more user-friendly names
            update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn)
            thunder_compiled_fns.append(jit_fn)
            submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.THUNDER)
        elif node.name.startswith("submod"):  # For inductor
            graph_module = getattr(split_gm, node.name)
            jit_fn = torch_inductor(graph_module)
            # Update the node name from "submod_*" to "inductor_*" for more user-friendly names
            update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
            submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.TORCH_INDUCTOR)
            is_split = True
        else:
            # Everything else is glue code that calls the other partitions and passes outputs between them.
            pass

    # We update the GraphModule in `update_node_and_submodule`, so we need to recompile.
    split_gm.recompile()

    return split_gm, SubgraphInfo(
        gm,
        split_gm,
        thunder_compiled_fns,
        submodule_to_compiled_fns,
        split_reasons,
    )
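For context, the contiguous-partition strategy above is built entirely on torch.fx.passes.split_module.split_module: the callback returns a partition index, and bumping that index whenever "supported" flips is what keeps each partition contiguous. The standalone sketch below (not part of this PR's diff) demonstrates the same callback idea on a toy graph; the choice of torch.sinc as the "unsupported" op and all names here are purely illustrative.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module


def fn(x):
    y = torch.sinc(x)  # pretend this op is unsupported by the target compiler
    return torch.matmul(x, y)


gm = symbolic_trace(fn)

partition_cnt = 0
prev_supported = None


def callback(node) -> int:
    # Bump the partition counter whenever "supported" flips, so each partition
    # is a contiguous run of supported (or unsupported) call nodes.
    global partition_cnt, prev_supported
    supported = node.target is not torch.sinc
    if prev_supported is not None and supported != prev_supported:
        partition_cnt += 1
    prev_supported = supported
    return partition_cnt


split_gm = split_module(gm, root_m=None, split_callback=callback, keep_original_order=True)
print(split_gm)  # parent module calls submod_0 (sinc) and submod_1 (matmul)

Running this prints a parent GraphModule that forwards to two submodules; in the PR's splitter those submodules are then renamed to inductor_* / thunder_* and compiled with torch_inductor or thunder_jit respectively.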
Review comment:
Let's create a new file splitter.py.

Reply:
Have created a new file utils.py as well, since it has the code for the dataclasses that are returned (which are not directly related to the splitter) and the splitter-related helper functions.