ThunderFX: handles the callable input of fx.Node #1548

Merged: 15 commits into main, Jan 17, 2025

Conversation

@kiya00 (Collaborator) commented Dec 12, 2024

Before submitting
  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #1539.

Needs #1463 and #1620 to pass no_grad regions to Thunder in ThunderFX; see also #1568.
As analyzed in #1539 (comment), this PR tries to fix the issue by removing the unused torch.autograd.function.FunctionCtx() call and processing the get_attr node.

@kiya00 marked this pull request as draft on December 12, 2024 20:10
@kshitij12345 (Collaborator) left a comment

Thank you for the fix, @kiya00. The fix itself looks good; however, the test failures look real.

Also, once everything works, I'd expect the following test to fail, since it shouldn't find any module that was passed to Inductor (so it will need an update):

def test_splitter_autograd_function(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
x = torch.ones(2, device=device, dtype=dtype, requires_grad=True)

This line should raise an assertion error:

assert any(target.startswith("inductor_") for target in targets)
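
For illustration, a hedged sketch of what the updated expectation might look like once the fix lands (the thunder_ prefix for Thunder-bound submodules is an assumption here):

# Sketch of the updated expectation (submodule naming assumed):
assert not any(target.startswith("inductor_") for target in targets)
assert any(target.startswith("thunder_") for target in targets)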

@kiya00 (Collaborator, Author) commented Dec 13, 2024

Hi @kshitij12345, I ran into two problems. Do you have any suggestions?

  1. When we use graph.eliminate_dead_code() to remove the unused function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None, other in-place ops can also be deleted, as in the test_thundercompiler_optim_step case. If we use graph.eliminate_dead_code(is_impure_node: Optional[Callable[[Node], bool]]), we need a function that specifies whether a node is impure.
# graph in test_thundercompiler_optim_step
def forward(self, L_self_param_groups_0_params_0_grad : torch.Tensor, L_self_param_groups_0_params_1_grad : torch.Tensor, L_self_param_groups_0_params_2_grad : torch.Tensor, L_self_param_groups_0_params_3_grad : torch.Tensor):
    l_self_param_groups_0_params_0_grad = L_self_param_groups_0_params_0_grad
    l_self_param_groups_0_params_1_grad = L_self_param_groups_0_params_1_grad
    l_self_param_groups_0_params_2_grad = L_self_param_groups_0_params_2_grad
    l_self_param_groups_0_params_3_grad = L_self_param_groups_0_params_3_grad
    p = self.self___param_groups_0__params___0
    p_1 = self.self___param_groups_0__params___1
    p_2 = self.self___param_groups_0__params___2
    p_3 = self.self___param_groups_0__params___3
    _foreach_add_ = torch._foreach_add_([p, p_1, p_2, p_3], [l_self_param_groups_0_params_0_grad, l_self_param_groups_0_params_1_grad, l_self_param_groups_0_params_2_grad, l_self_param_groups_0_params_3_grad], alpha = -0.001);  p = p_1 = p_2 = p_3 = l_self_param_groups_0_params_0_grad = l_self_param_groups_0_params_1_grad = l_self_param_groups_0_params_2_grad = l_self_param_groups_0_params_3_grad = None
    return ()

Or do we add a specific pass that deletes the torch.autograd.function.FunctionCtx() node? Or maybe support the torch.autograd.function.FunctionCtx op?
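
For illustration, a minimal sketch of the dedicated-pass option (not necessarily the exact pass adopted in this PR): it erases only FunctionCtx constructor nodes that have no users, so in-place ops such as torch._foreach_add_ are left untouched.

import torch
import torch.fx as fx

def remove_unused_function_ctx(gm: fx.GraphModule) -> fx.GraphModule:
    # Erase only `function_ctx = torch.autograd.function.FunctionCtx()` nodes
    # that have no users; every other node (including in-place ops) is kept.
    for node in list(gm.graph.nodes):
        if (
            node.op == "call_function"
            and node.target is torch.autograd.function.FunctionCtx
            and len(node.users) == 0
        ):
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm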

  2. When we check whether Thunder supports autograd_function_apply in the splitter, I assume it should be treated as supported only if the two fwd/bwd input submodules are fully supported. But there is a _set_grad_enabled(False) call inside both submodules, which causes autograd_function_apply to always be reported as unsupported by Thunder (a sketch of the intended check follows the graph dumps below).

Here is the FX graph structure of autograd_function_apply

GraphModule(
  (fwd_body_0): GraphModule()
  (bwd_body_0): GraphModule()
)

def forward(self, s0 : torch.SymInt, L_x_ : torch.Tensor):
    l_x_ = L_x_
    cos = torch.cos(l_x_)
    function_ctx = torch.autograd.function.FunctionCtx();  function_ctx = None
    fwd_body_0 = self.fwd_body_0
    bwd_body_0 = self.bwd_body_0
    autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, s0, args_tensor_mask = [True], non_differentiable_idx = []);  fwd_body_0 = bwd_body_0 = s0 = None
    y = cos + autograd_function_apply;  cos = autograd_function_apply = None
    matmul = torch.matmul(l_x_, y);  l_x_ = y = None
    return (matmul,)

the fwd_body_0 module:

def forward(self, ctx : torch.autograd.function.Function, x : torch.Tensor, s0 : torch.SymInt):
    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
    sin = torch.sin(x)
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
    return (sin, [s0, x])
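
To make the intended support check concrete, here is a hypothetical helper (the name and the is_node_supported callback are assumptions, not code from this PR) that treats autograd_function_apply as supported only when every node of both submodules is supported:

import torch
from torch.fx import GraphModule, Node

def autograd_function_apply_supported(gm: GraphModule, node: Node, is_node_supported) -> bool:
    if node.target is not torch.ops.higher_order.autograd_function_apply:
        return True
    # node.args[0] and node.args[1] are the get_attr nodes for fwd_body_0 / bwd_body_0.
    fwd_attr, bwd_attr = node.args[0], node.args[1]
    submodules = (gm.get_submodule(fwd_attr.target), gm.get_submodule(bwd_attr.target))
    # Supported only if every node in both the forward and backward bodies is supported.
    return all(is_node_supported(n) for sub in submodules for n in sub.graph.nodes)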

P.S. When we fix the above problems, we also need to let the fwd_body_0 module go through the converter because, for the same reason as with the checkpoint operator, Thunder cannot trace the torch.sin call inside the submodule.

cc: @IvanYashchuk

@kshitij12345 (Collaborator) commented Dec 13, 2024

Regarding 1., I am not very familiar with graph.eliminate_dead_code, so if using it is tricky, then adding a pass that just removes the unused function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None sounds good. I am curious how the torch.compile pipeline deals with this.

For 2., #1463 should take care of this, as regions within _set_grad_enabled will be passed to Thunder.

@kiya00 (Collaborator, Author) commented Dec 16, 2024

After cherry-picking #1463, autograd_function_apply is compiled by Thunder, but the trace still has a problem caused by _set_grad_enabled(False).
The FX graph structure of autograd_function_apply and the fwd_body_0 module are the same as shown above.

The augmented forward trace is:

@torch.no_grad()
@no_autocast
def computation(l_x_):
  # l_x_: "cpu f32[2]"

  # <eval_with_key>.16:7:           cos = torch.cos(l_x_)
  cos = prims.cos(l_x_)  # cos: "cpu f32[2]"

  # <eval_with_key>.15:8:           _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
  autograd_function_apply = higher_order_autograd_function_apply_vvytl_130871302739856_0(None, l_x_)  # autograd_function_apply: "cpu f32[2]"
    # autograd_function_apply = ltorch.sin(l_x_)  # autograd_function_apply: "cpu f32[2]"
      # autograd_function_apply = prims.sin(l_x_)  # autograd_function_apply: "cpu f32[2]"

  # <eval_with_key>.16:9:           y = cos + autograd_function_apply;  cos = autograd_function_apply = None
  y = ltorch.add(cos, autograd_function_apply, alpha=1)  # y: "cpu f32[2]"
    # y = prims.add(cos, autograd_function_apply)  # y: "cpu f32[2]"

  # <eval_with_key>.16:10:          matmul = torch.matmul(l_x_, y);  l_x_ = y = None
  matmul = prims.matmul(l_x_, y)  # matmul: "cpu f32[]"
  return {'output': (matmul,), 'flat_args': [l_x_], 'flat_output': (matmul,)}, ((l_x_, y), ()), 

Note that higher_order_autograd_function_apply_vvytl_130871302739856_0 corresponds to the fwd_body_0 module above, and its output carries the ProxyTag.DETACHED_AUTOGRAD_GRAPH tag, so its grad is skipped in the backward trace:

if is_constant_for_vjp(symbol):
    # We can skip the pullback if all the arguments are constant
    continue

I think the correct behavior should be to use the grad transform:
def grad_transform(*args, **kwargs):

Should we maybe remove the torch._C._set_grad_enabled calls in the fwd_body_0 module? @kshitij12345 @IvanYashchuk, do you have any suggestions?
It seems that if I let is_constant_for_vjp return False for the higher_order_autograd_function_apply symbol, even though its output has the ProxyTag.DETACHED_AUTOGRAD_GRAPH tag, the backward trace comes out as expected.
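
A minimal sketch of that experiment, wrapping is_constant_for_vjp rather than editing it in place (the thunder.core.transforms import path is an assumption):

from thunder.core.transforms import is_constant_for_vjp  # assumed location

_original_is_constant_for_vjp = is_constant_for_vjp

def patched_is_constant_for_vjp(symbol) -> bool:
    # Never treat autograd_function_apply as constant, even when its output
    # carries ProxyTag.DETACHED_AUTOGRAD_GRAPH, so that its grad rule is
    # applied in the backward pass instead of being skipped.
    if "autograd_function_apply" in str(symbol.sym.id):
        return False
    return _original_is_constant_for_vjp(symbol)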

@kshitij12345 (Collaborator) commented:
Thanks for the explanation @kiya00.

I was wondering: since the line below goes through thunder.jit, the relevant code from jit_ext that creates a new symbol with a grad rule should be applied automatically, right? What is the backward trace that is being generated?

autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, s0, args_tensor_mask = [True], non_differentiable_idx = []);

@kiya00 (Collaborator, Author) commented Dec 18, 2024

I was wondering: since the line below goes through thunder.jit, the relevant code from jit_ext that creates a new symbol with a grad rule should be applied automatically, right? What is the backward trace that is being generated?

Yes, it happens automatically: the backward of the torch.ops.higher_order.autograd_function_apply symbol is created in backward_pass (L2764), but because fwd_body_0 runs under no_grad, the loop continues at L2744, so no backward is added for autograd_function_apply:

if is_constant_for_vjp(symbol):
    # We can skip the pullback if all the arguments are constant
    continue

if all(cotangent is None for cotangent in cotangents):
    # We can skip the pullback if the cotangent is None
    safe_map(put_grad, symbol.args, (None,) * len(symbol.args))
    continue

if symbol.sym.id == "torch.nn.functional.dropout" and not symbol.subsymbols:
    # We can skip the pullback if the dropout probability is 0.0
    # Assuming that the dropout symbol has the same output and argument
    assert symbol.output.name == symbol.args[0].name, "Dropout symbol has a different output and argument"
    if symbol.args[1] == 0.0 or symbol.args[2] is False:
        continue

backward = backward_impls.get(symbol.sym.id)
aug_forward = augmented_forward_impls.get(symbol.sym.id)

if _get_gradfn_and_executor(symbol)[0] is not None:
    aug_forward, backward = make_aug_forward_and_backward(symbol)

So I tried making is_constant_for_vjp always return False for autograd_function_apply.
The forward trace is then (corresponding to the test case in this PR):

@torch.no_grad()
@no_autocast
def computation(l_x_):
  # l_x_: "cuda:0 f32[2]"
  [y] = nvFusion0(l_x_)
    # cos = prims.cos(l_x_)  # cos: "cuda:0 f32[2]"
    # autograd_function_apply = prims.sin(l_x_)  # autograd_function_apply: "cuda:0 f32[2]"
    # y = prims.add(cos, autograd_function_apply)  # y: "cuda:0 f32[2]"

  # <eval_with_key>.121:10:         matmul = torch.matmul(l_x_, y);  l_x_ = y = None
  matmul = torch.matmul(l_x_, y)  # matmul: "cuda:0 f32[]"
    # matmul = ltorch.matmul(l_x_, y)  # matmul: "cuda:0 f32[]"
      # matmul = prims.matmul(l_x_, y)  # matmul: "cuda:0 f32[]"
  return {'output': (matmul,), 'flat_args': [l_x_], 'flat_output': (matmul,)}, ((l_x_, y), ())

The backward trace:

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t0, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  l_x_, y, = C0
  clear_mutable_collection(C0)
  del C0
  [t28] = nvFusion0(l_x_, t0, y)
    # t21 = prims.cos(l_x_)  # t21: "cuda:0 f32[2]"
    # t25 = prims.sin(l_x_)  # t25: "cuda:0 f32[2]"
    # t17 = prims.broadcast_in_dim(t0, (2,), ())  # t17: "cuda:0 f32[2]"
    # t18 = prims.mul(t17, y)  # t18: "cuda:0 f32[2]"
    # t20 = prims.mul(t17, l_x_)  # t20: "cuda:0 f32[2]"
    # t22 = prims.mul(t20, t21)  # t22: "cuda:0 f32[2]"
    # t23 = prims.mul(t22, 100.0)  # t23: "cuda:0 f32[2]"
    # t24 = prims.add(t18, t23)  # t24: "cuda:0 f32[2]"
    # t26 = prims.neg(t25)  # t26: "cuda:0 f32[2]"
    # t27 = prims.mul(t20, t26)  # t27: "cuda:0 f32[2]"
    # t28 = prims.add(t24, t27)  # t28: "cuda:0 f32[2]"
  del l_x_, t0, y
  return (t28,)

@kiya00 marked this pull request as ready for review on January 13, 2025 13:51
@kiya00 (Collaborator, Author) commented Jan 13, 2025

Hi @kshitij12345, @IvanYashchuk, it's ready for review.

@kshitij12345 (Collaborator) left a comment

I have left small suggestions and tagged @crcrpar to review jit_ext.py, as I think he is very familiar with it. Other than that, everything looks really good. Thank you, @kiya00!

Review threads on thunder/dynamo/utils.py, thunder/tests/test_dynamo.py, and thunder/core/jit_ext.py (resolved).
kiya00 and others added 2 commits January 16, 2025 14:09
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
Co-authored-by: Kshiteej K <kshitijkalambarkar@gmail.com>
@kiya00 (Collaborator, Author) commented Jan 16, 2025

Hi @IvanYashchuk, do you want to take another look before merging?

@IvanYashchuk (Collaborator) left a comment

Let's merge. Thank you all!

@IvanYashchuk enabled auto-merge (squash) on January 17, 2025 10:28
@IvanYashchuk (Collaborator) commented:
@t-vi, could you please approve this one?

@IvanYashchuk added the thunderfx (for things that could be applicable to the dynamo+thunder frontend) and autograd labels on Jan 17, 2025
@t-vi (Collaborator) left a review.
@IvanYashchuk merged commit 628958d into main on Jan 17, 2025
49 checks passed
@IvanYashchuk deleted the tryfix1539 branch on January 17, 2025 16:37
riccardofelluga pushed a commit that referenced this pull request Jan 27, 2025
(The commit message repeats the PR description above.)
Labels: autograd, thunderfx (for things that could be applicable to the dynamo+thunder frontend)
Participants: 5