
Add splitter to dynamo backend #5

Closed · wants to merge 9 commits

Conversation

@kshitij12345 (Owner) commented on Aug 24, 2024:

This PR is to facilitate viewing the complete changes and leaving comments.

Figure out:

  1. Default dtype for factory functions. For example, what dtype should the ones call below produce?

     import thunder
     from thunder.core.proxies import proxy
     from thunder.core.trace import tracectx, TraceCtx
     import torch

     my_trc = TraceCtx()
     with tracectx(my_trc):
         thunder.torch.ones(3, 3)

  2. SymInts when they are part of the shape of a FakeTensor. (Know how this will interact with future dynamic-shape support in thunder; see the sketch after this list.)
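
A minimal sketch of how the SymInt case could be detected, assuming the FX graph comes from Dynamo so that each node carries a FakeTensor under node.meta["example_value"] (the helper name has_symbolic_shape is made up for illustration):

import torch

def has_symbolic_shape(node: torch.fx.Node) -> bool:
    # Dynamo-produced graphs usually attach a FakeTensor as "example_value";
    # its shape entries are plain ints, or torch.SymInt for dynamic dimensions.
    example = node.meta.get("example_value", None)
    if not isinstance(example, torch.Tensor):
        return False
    return any(isinstance(dim, torch.SymInt) for dim in example.shape)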

@kshitij12345 force-pushed the torch-compile-splitter branch from de6c681 to 9f3fc95 on August 26, 2024 at 11:22
Comment on lines 36 to 37
op_support = ThunderOperatorSupport(gm_copy)
partitioner = CapabilityBasedPartitioner(gm_copy, op_support)


How do you like their design of instantiating a separate class for operator support? I don't like it, and here's how we can avoid using it (by implementing __is_node_supported):

class ThunderPartitioner(CapabilityBasedPartitioner):
    def __init__(self, gm):
        supported_ops = None # This is overridden by __is_node_supported
        super().__init__(gm, supported_ops, allows_single_node_partition=True)

    def __is_node_supported(self, node: torch.fx.Node) -> bool:
        # node.meta could be inspected for metadata about the node
        # For now, we just allow all nodes. Note that this is not a good idea
        # and ThunderSupport should be updated to only allow nodes that are
        # supported by Thunder
        return True
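
For context, a usage sketch of how such a subclass would be driven (assuming gm is the torch.fx.GraphModule handed to the Dynamo backend):

partitioner = ThunderPartitioner(gm)
partitions = partitioner.propose_partitions()
fused_gm = partitioner.fuse_partitions(partitions)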


With this subclass we can also override fuse_partitions to support passing our own name:

class ThunderPartitioner(CapabilityBasedPartitioner):
    ...
    def fuse_partitions(self, partitions: list[Partition]) -> torch.fx.GraphModule:
        # fuse_by_partitions expects partitions as list[list[Node]]: [ [node0, node1], [node2, node3] ]
        return fuse_by_partitions(
            self.graph_module,
            [list(partition.nodes) for partition in partitions],
            name_prefix="thunder_",
        )

# fx.passes.utils.fuser_utils.fuse_by_partitions but with a name_prefix parameter
def fuse_by_partitions(gm: torch.fx.GraphModule, partitions: list[list[torch.fx.Node]], name_prefix: str) -> torch.fx.GraphModule:
    from torch.fx.passes.utils.fuser_utils import topo_sort, insert_subgm, erase_nodes, fuse_as_graphmodule, legalize_graph

    for partition_id, nodes in enumerate(partitions):
        sorted_nodes = topo_sort(nodes)

        # Each fused submodule gets a descriptive name, e.g. "thunder_0", "thunder_1", ...
        submodule_name = name_prefix + str(partition_id)
        sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name)

        insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
        erase_nodes(gm, sorted_nodes)

    legalize_graph(gm)
    return gm

@kshitij12345 (Owner, Author) replied:

I have removed the usage of the OperatorSupport class and use a function instead.

Also, the function only cares about whether a node is supported by thunder; the logic for whether it is in an unsupported ctx region has been moved outside.

gm_copy = copy.deepcopy(gm)
op_support = ThunderOperatorSupport(gm_copy)
partitioner = CapabilityBasedPartitioner(gm_copy, op_support)
fused_partition = partitioner.partition_and_fuse()


This function is so small that I think it's more readable to use directly what's inside it:

Suggested change:
- fused_partition = partitioner.partition_and_fuse()
+ partitions = partitioner.propose_partitions()
+ fused_partition = partitioner.fuse_partitions(partitions)

Now we can implement our own .fuse_partitions, which is better because it can use a more descriptive name than simply "fused_#".

# return matmul


class GraphModuleSplitter(torch.fx.passes.splitter_base._SplitterBase):


Wow, they have a "splitter" and a "partitioner"? 🫨 What's the difference?

supported = True
for node in gm.graph.nodes:
    if node.op in ["call_method", "call_function"]:
        supported = op_support.is_node_supported(gm, node)


Is ThunderOperatorSupport really needed? Can it be simply a function is_node_supported_by_thunder?

@kshitij12345 (Owner, Author) replied:

I have removed the usage of the OperatorSupport class and use a function is_node_supported_by_thunder instead.

Also, the function only cares about whether a node is supported by thunder; the logic for whether it is in an unsupported ctx region has been moved outside.
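
A minimal sketch of what such a function could look like; the _torch_to_thunder_map table below is a made-up stand-in for whatever registry of supported torch operations thunder maintains internally:

import torch

# Hypothetical lookup table; a real check would consult thunder's own registry
# of torch operations it can trace.
_torch_to_thunder_map = {torch.sin, torch.matmul, torch.ones}

def is_node_supported_by_thunder(node: torch.fx.Node) -> bool:
    # Placeholder, get_attr and output nodes are handled by the splitter itself.
    if node.op not in ("call_function", "call_method"):
        return True
    # For call_function the target is the callable itself; for call_method it is
    # a method-name string, which would need its own lookup in a full implementation.
    return node.target in _torch_to_thunder_map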

)

def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
    return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions


Instead of "submod" can we use more meaningful names from the start like "thunder_#" and "inductor_#"?

@kshitij12345 (Owner, Author) replied:

To have meaningful names from the start (i.e. from the split_module function), we would have to maintain our own implementation of it.

However, I am a bit sceptical about doing that, as I see there have been a few real fixes to the file in the last year or so (not just refactoring and better engineering):
https://github.com/pytorch/pytorch/commits/main/torch/fx/passes/split_module.py

Instead, we can update the graph module returned by split_module and rename the submodules so that the user will see informative names like thunder_* and inductor_*. This is what the current code is doing.

I wanted to know your opinion on this.

Example of a split subgraph that the user would currently see:

class GraphModule(torch.nn.Module):
    def forward(self, l_x_: "f32[2]"):
        # No stacktrace found for following nodes
        thunder_1 = self.thunder_1(l_x_);  l_x_ = None
        return (thunder_1,)
        
    class thunder_1(torch.nn.Module):
        def forward(self, l_x_: "f32[2]"):
             # File: /home/kkalambarkar/lightning-thunder/scratchpad/test_splitter.py:13 in func, code: y = torch.sin(x) + x * (x + 1) + torch.ones(2)
            sin: "f32[2]" = torch.sin(l_x_)
            add: "f32[2]" = l_x_ + 1
            mul: "f32[2]" = l_x_ * add;  add = None
            add_1: "f32[2]" = sin + mul;  sin = mul = None
            ones: "f32[2]" = torch.ones(2)
            y: "f32[2]" = add_1 + ones;  add_1 = ones = None
            
             # File: /home/kkalambarkar/lightning-thunder/scratchpad/test_splitter.py:14 in func, code: return torch.matmul(x, y)
            matmul: "f32[]" = torch.matmul(l_x_, y);  l_x_ = y = None
            return matmul
            
        class _model(torch.nn.Module):
            def forward(self, l_x_: "f32[2]"):
                 # File: /home/kkalambarkar/lightning-thunder/scratchpad/test_splitter.py:13 in func, code: y = torch.sin(x) + x * (x + 1) + torch.ones(2)
                sin: "f32[2]" = torch.sin(l_x_)
                add: "f32[2]" = l_x_ + 1
                mul: "f32[2]" = l_x_ * add;  add = None
                add_1: "f32[2]" = sin + mul;  sin = mul = None
                ones: "f32[2]" = torch.ones(2)
                y: "f32[2]" = add_1 + ones;  add_1 = ones = None
                
                 # File: /home/kkalambarkar/lightning-thunder/scratchpad/test_splitter.py:14 in func, code: return torch.matmul(x, y)
                matmul: "f32[]" = torch.matmul(l_x_, y);  l_x_ = y = None
                return matmul
            


How much of the file relies on split_callback(node) returning an integer? What happens if it starts to return a string that would be used here for the name?
https://github.com/pytorch/pytorch/blob/50efbb9f1e7111b4b6d5b8e9a6064ee9783930be/torch/fx/passes/split_module.py#L236

Another option is to patch the Partition class and modify submod_name, and then restore the original class implementation.
https://github.com/pytorch/pytorch/blob/50efbb9f1e7111b4b6d5b8e9a6064ee9783930be/torch/fx/passes/split_module.py#L18

> Instead, we can update the graph module returned by split_module and rename the submodules so that the user will see informative names like thunder_* and inductor_*. This is what the current code is doing.

This is great! No need to come up with workarounds.
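
A minimal sketch of that renaming pass, assuming split_module named every submodule submod_<i> and reusing a predicate like is_thunder_supported_partition from above (rename_split_submodules is a made-up name for illustration):

import torch

def rename_split_submodules(gm: torch.fx.GraphModule, is_thunder_partition) -> torch.fx.GraphModule:
    thunder_count = inductor_count = 0
    for node in gm.graph.nodes:
        if node.op != "call_module" or not node.name.startswith("submod"):
            continue
        if is_thunder_partition(node):
            new_name, thunder_count = f"thunder_{thunder_count}", thunder_count + 1
        else:
            new_name, inductor_count = f"inductor_{inductor_count}", inductor_count + 1
        # Re-register the submodule under the new name and retarget the call_module node.
        submodule = gm.get_submodule(node.target)
        delattr(gm, node.target)
        setattr(gm, new_name, submodule)
        node.target = new_name
        node.name = new_name
    gm.recompile()
    return gm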



Let's create a new file splitter.py.

@kshitij12345 (Owner, Author) replied:

I have created a new file utils.py, as it contains the dataclasses that are returned (which are not directly related to the splitter) as well as the splitter-related functions.

@thunder._with_cache_info_ctx
def _run_with_cache_info():

    # We need cache info here as the default dtype and device support
@kshitij12345 (Owner, Author) replied:

I think we should pursue this separately after this PR. Wdyt?


Yes, definitely a separate PR type of work. Can Thunder functions query PyTorch themselves for the current default type and device? If thunder.new_ones implementation is modified to do that we don't need to query the cache and pass it explicitly.
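
A minimal sketch of that idea, assuming the factory implementation is free to fall back to PyTorch's global defaults when dtype/device are not passed explicitly (new_ones_like_default is an illustrative name, not thunder's actual implementation):

import torch

def new_ones_like_default(shape, dtype=None, device=None):
    # Query PyTorch's global defaults instead of threading them through the cache info.
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        # torch.get_default_device() exists in recent PyTorch; creating a
        # zero-element tensor is an older, version-agnostic way to query it.
        device = torch.empty(0).device
    return torch.ones(shape, dtype=dtype, device=device)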

@kshitij12345 (Owner, Author) commented:

@IvanYashchuk, I have made most of the changes from our discussion yesterday (and have dropped comments where I felt otherwise), and the PR should be ready for another look. Thanks!

The failures on Windows look real; investigating them. They could be related to pytorch/pytorch#122094.

@IvanYashchuk left a comment:

Awesome! Let's move this to the Thunder repo!

@@ -13,14 +31,15 @@ def _warn_thunder_compiler():


 class ThunderCompiler:
-    def __init__(self, **thunder_options):
+    def __init__(self, *, thunder_options: dict | None = None, torch_inductor_options: dict | None = None):


I think it's less cognitive load to switch from thunder.jit to ThunderCompiler if the way to pass Thunder options is the same.
I suggest keeping **kwargs and then popping "torch_inductor_options" from that dict. For example, specifying executors for Thunder and a torch.compile option would look like:

backend = ThunderCompiler(executors=[...], torch_inductor_options={"mode": "max-autotune"})
compiled_module = torch.compile(backend=backend)(module)
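
A minimal sketch of that signature, assuming the remaining keyword arguments are later forwarded to thunder.jit (the attribute names are illustrative, not necessarily the PR's):

class ThunderCompiler:
    def __init__(self, **thunder_options):
        # Pull out the torch.compile/Inductor options; everything else is treated
        # as thunder.jit options, mirroring the thunder.jit interface.
        self.torch_inductor_options = thunder_options.pop("torch_inductor_options", {})
        self.thunder_options = thunder_options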

Comment on lines 79 to 81
    def _splitter(
        self, gm: torch.fx.GraphModule, _unused_sample_args: list[torch.SymInt, torch.Tensor]
    ) -> torch.fx.GraphModule:


This method doesn't seem to use the self attribute. Let's move to a separate file splitter.py.


@IvanYashchuk commented:

> The failures on Windows look real; investigating them. They could be related to pytorch/pytorch#122094.

Yes, on Windows let's error out saying that torch.compile with Inductor is not supported. nvFuser is currently also not available on Windows; it was working before when it was part of the PyTorch code base, so maybe it's not too much work to enable it, but it's more work to maintain the CI. Tests should then be skipped on Windows.
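
A minimal sketch of both pieces, with the error check placed wherever the Inductor path is entered and pytest used for the test-side skip (the exact placement and message are assumptions):

import sys

import pytest


def _check_inductor_supported():
    # Inductor (and currently nvFuser) is not available on Windows, so fail early
    # with a clear message instead of deep inside torch.compile.
    if sys.platform == "win32":
        raise RuntimeError("torch.compile with the Inductor backend is not supported on Windows")


# In thunder/tests/test_dynamo.py, skip instead of failing on Windows:
requires_non_windows = pytest.mark.skipif(
    sys.platform == "win32",
    reason="torch.compile with Inductor is not supported on Windows",
)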
