Add splitter to dynamo backend #5
Conversation
op_support = ThunderOperatorSupport(gm_copy)
partitioner = CapabilityBasedPartitioner(gm_copy, op_support)
How do you like their design of instantiating a separate class for operator support? I don't like it, and here's how we can avoid using it (by implementing __is_node_supported):
class ThunderPartitioner(CapabilityBasedPartitioner):
    def __init__(self, gm):
        supported_ops = None  # This is overridden by __is_node_supported
        super().__init__(gm, supported_ops, allows_single_node_partition=True)

    def __is_node_supported(self, node: torch.fx.Node) -> bool:
        # node.meta could be inspected for metadata about the node.
        # For now, we just allow all nodes. Note that this is not a good idea
        # and ThunderSupport should be updated to only allow nodes that are
        # supported by Thunder.
        return True
With this subclass we can also override fuse_partitions to support passing our own name:
class ThunderPartitioner(CapabilityBasedPartitioner):
    ...

    def fuse_partitions(self, partitions: list[Partition]) -> torch.fx.GraphModule:
        # fuse_by_partitions expects partitions as list[list[Node]]: [[node0, node1], [node2, node3]]
        # "thunder_" is just an example prefix here.
        return fuse_by_partitions(
            self.graph_module, [list(partition.nodes) for partition in partitions], name_prefix="thunder_"
        )


# fx.passes.utils.fuser_utils.fuse_by_partitions but with the name_prefix parameter
def fuse_by_partitions(gm: torch.fx.GraphModule, partitions: list[list[torch.fx.Node]], name_prefix: str) -> torch.fx.GraphModule:
    from torch.fx.passes.utils.fuser_utils import topo_sort, insert_subgm, erase_nodes, fuse_as_graphmodule, legalize_graph

    for partition_id, nodes in enumerate(partitions):
        sorted_nodes = topo_sort(nodes)
        submodule_name = name_prefix + str(partition_id)
        sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name)
        insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
        erase_nodes(gm, sorted_nodes)
    legalize_graph(gm)
    return gm
I have removed the usage of the OperatorSupport class and use a function instead.
Also, the function only checks whether a node is supported by Thunder; the logic for whether it is in an unsupported context region has been moved outside.
gm_copy = copy.deepcopy(gm)
op_support = ThunderOperatorSupport(gm_copy)
partitioner = CapabilityBasedPartitioner(gm_copy, op_support)
fused_partition = partitioner.partition_and_fuse()
This function is so small that I think it's more readable to inline what it does:
- fused_partition = partitioner.partition_and_fuse()
+ partitions = partitioner.propose_partitions()
+ fused_partition = partitioner.fuse_partitions(partitions)
Now we can implement our own .fuse_partitions, which is better because it can use a more descriptive name than simply "fused_#".
# return matmul


class GraphModuleSplitter(torch.fx.passes.splitter_base._SplitterBase):
Wow, they have a "splitter" and a "partitioner"? 🫨 What's the difference?
thunder/dynamo/compiler.py (Outdated)
supported = True
for node in gm.graph.nodes:
    if node.op in ["call_method", "call_function"]:
        supported = op_support.is_node_supported(gm, node)
Is ThunderOperatorSupport really needed? Can it simply be a function is_node_supported_by_thunder?
I have removed the usage of the OperatorSupport class and use a function is_node_supported_by_thunder instead.
Also, the function only checks whether a node is supported by Thunder; the logic for whether it is in an unsupported context region has been moved outside.
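For reference, a minimal sketch of what such a function could look like; the _torch_to_thunder_function_map lookup below is an assumption about Thunder's internals, and the real check in the PR may be more involved:

from torch.fx import Node


def is_node_supported_by_thunder(node: Node) -> bool:
    # Only call_function/call_method nodes need checking; placeholder, get_attr
    # and output nodes are handled by the surrounding splitter logic.
    if node.op not in ("call_function", "call_method"):
        return True
    if node.op == "call_method":
        # For call_method nodes the target is a string (the method name), so a
        # different lookup would be needed; conservatively mark them unsupported.
        return False
    # Assumed lookup table mapping torch callables to Thunder symbols.
    from thunder.torch import _torch_to_thunder_function_map

    return node.target in _torch_to_thunder_function_map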
thunder/dynamo/compiler.py (Outdated)
)

def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
    return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions
Instead of "submod" can we use more meaningful names from the start like "thunder_#" and "inductor_#"?
To have meaningful names from the start (i.e. from the split_module function), we would have to have our own implementation of it.
However, I am a bit sceptical about doing that, as I see there have been a few real fixes to that file in the last year or so (not just refactoring and better engineering):
https://github.com/pytorch/pytorch/commits/main/torch/fx/passes/split_module.py
Instead, we can update the graph module returned by split_module and rename the submodules so that the user will see informative names like thunder_* and inductor_*. This is what the current code is doing.
I wanted to know your opinion on this.
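To make that renaming concrete, here is a minimal sketch of such a post-split pass; the names rename_split_submodules and supported_partitions, and the "submod_<i>" convention, are assumptions based on this thread rather than the actual implementation:

import torch


def rename_split_submodules(split_gm: torch.fx.GraphModule, supported_partitions: set[int]) -> torch.fx.GraphModule:
    # Rename the "submod_<i>" children produced by split_module to
    # "thunder_<i>" or "inductor_<i>" depending on which backend runs them.
    for node in split_gm.graph.nodes:
        if node.op != "call_module" or not str(node.target).startswith("submod_"):
            continue
        idx = int(str(node.target).replace("submod_", ""))
        new_name = f"thunder_{idx}" if idx in supported_partitions else f"inductor_{idx}"
        split_gm.add_submodule(new_name, split_gm.get_submodule(str(node.target)))
        split_gm.delete_submodule(str(node.target))
        node.target = new_name  # point the call_module node at the renamed child
    split_gm.recompile()
    return split_gm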
Example of the split graph that the user would currently see:
class GraphModule(torch.nn.Module):
    def forward(self, l_x_: "f32[2]"):
        # No stacktrace found for following nodes
        thunder_1 = self.thunder_1(l_x_); l_x_ = None
        return (thunder_1,)

    class thunder_1(torch.nn.Module):
        def forward(self, l_x_: "f32[2]"):
            # File: /home/kkalambarkar/lightning-thunder/scratchpad/test_splitter.py:13 in func, code: y = torch.sin(x) + x * (x + 1) + torch.ones(2)
            sin: "f32[2]" = torch.sin(l_x_)
            add: "f32[2]" = l_x_ + 1
            mul: "f32[2]" = l_x_ * add; add = None
            add_1: "f32[2]" = sin + mul; sin = mul = None
            ones: "f32[2]" = torch.ones(2)
            y: "f32[2]" = add_1 + ones; add_1 = ones = None

            # File: /home/kkalambarkar/lightning-thunder/scratchpad/test_splitter.py:14 in func, code: return torch.matmul(x, y)
            matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
            return matmul

    class _model(torch.nn.Module):
        def forward(self, l_x_: "f32[2]"):
            # File: /home/kkalambarkar/lightning-thunder/scratchpad/test_splitter.py:13 in func, code: y = torch.sin(x) + x * (x + 1) + torch.ones(2)
            sin: "f32[2]" = torch.sin(l_x_)
            add: "f32[2]" = l_x_ + 1
            mul: "f32[2]" = l_x_ * add; add = None
            add_1: "f32[2]" = sin + mul; sin = mul = None
            ones: "f32[2]" = torch.ones(2)
            y: "f32[2]" = add_1 + ones; add_1 = ones = None

            # File: /home/kkalambarkar/lightning-thunder/scratchpad/test_splitter.py:14 in func, code: return torch.matmul(x, y)
            matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
            return matmul
How much of the file relies on split_callback(node) returning an integer? What happens if it starts to return a string that would be used for the name here:
https://github.com/pytorch/pytorch/blob/50efbb9f1e7111b4b6d5b8e9a6064ee9783930be/torch/fx/passes/split_module.py#L236
Another option is to patch the Partition class, modify submod_name, and then restore the original class implementation:
https://github.com/pytorch/pytorch/blob/50efbb9f1e7111b4b6d5b8e9a6064ee9783930be/torch/fx/passes/split_module.py#L18
> Instead, we can update the graph module returned by split_module and rename the submodules so that the user will see informative names like thunder_* and inductor_*. This is what the current code is doing.

This is great! No need to come up with workarounds.
Let's create a new file splitter.py.
I have created a new file utils.py instead, as it contains both the dataclasses that are returned (which are not directly related to the splitter) and the splitter-related functions.
@thunder._with_cache_info_ctx
def _run_with_cache_info():

    # We need cache info here as the default dtype and device support
I think we should pursue this separately after this PR. Wdyt?
Yes, definitely a separate-PR type of work. Can Thunder functions query PyTorch themselves for the current default dtype and device? If the thunder.new_ones implementation is modified to do that, we don't need to query the cache and pass it explicitly.
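For context, a minimal sketch of querying the current defaults directly from PyTorch (how this would be wired into the thunder.new_ones implementation is left open here):

import torch


def current_torch_defaults() -> tuple[torch.dtype, torch.device]:
    # The default dtype is exposed directly; the default device can be observed
    # from a freshly allocated tensor, which respects torch.set_default_device().
    return torch.get_default_dtype(), torch.empty(0).device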
@IvanYashchuk, I have made most of the changes from our discussion yesterday (and have left comments where I felt otherwise), and the PR should be ready for another look. Thanks! The failures on Windows look real - I am investigating them; they could be related to pytorch/pytorch#122094
Awesome! Let's move this to the Thunder repo!
thunder/dynamo/compiler.py (Outdated)
@@ -13,14 +31,15 @@ def _warn_thunder_compiler():


 class ThunderCompiler:
-    def __init__(self, **thunder_options):
+    def __init__(self, *, thunder_options: dict | None = None, torch_inductor_options: dict | None = None):
I think it's less cognitive load to switch from thunder.jit to ThunderCompiler if the way to pass Thunder options is the same.
I suggest keeping **kwargs and then popping "torch_inductor_options" from this dict. For example, specifying executors for Thunder and a torch.compile option would look like:
backend = ThunderCompiler(executors=[...], torch_inductor_options={"mode": "max-autotune"})
compiled_module = torch.compile(backend=backend)(module)
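A minimal sketch of the suggested constructor; the attribute names are placeholders, not the PR's actual code:

class ThunderCompiler:
    def __init__(self, **thunder_options):
        # Separate out the torch.compile/Inductor options; everything else is
        # forwarded to thunder.jit unchanged, mirroring its keyword interface.
        self.torch_inductor_options = thunder_options.pop("torch_inductor_options", {})
        self.thunder_options = thunder_options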
thunder/dynamo/compiler.py (Outdated)
def _splitter(
    self, gm: torch.fx.GraphModule, _unused_sample_args: list[torch.SymInt, torch.Tensor]
) -> torch.fx.GraphModule:
This method doesn't seem to use self. Let's move it to a separate file splitter.py.
Yes, on Windows let's error out saying that
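A minimal sketch of such an early check; the error message is a placeholder since the exact wording is not specified in this thread:

import platform


def _check_platform_support() -> None:
    # Hypothetical guard: fail early on Windows (see pytorch/pytorch#122094)
    # instead of hitting the real failures later.
    if platform.system() == "Windows":
        raise RuntimeError("The Thunder dynamo backend is currently not supported on Windows.")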
This PR is meant to facilitate viewing the complete changes and leaving comments.
Figure out