
Functionalize in-place ops #584

Merged: 38 commits into main on Jun 20, 2024

Conversation

crcrpar
Collaborator

@crcrpar crcrpar commented Jun 12, 2024

What does this PR do?

Express in-place ops using their out-of-place counterpart followed by prims.copy_, e.g. a.add_(b) as prims.copy_(prims.add(a, b), a), then let the added transform functionalize the trace by removing the trailing copy and updating the signature.
Let's say we have t.exp_() in a script and t is used afterwards. Thunder translates it into t_out = ltorch.exp_(t). This bound symbol has two sub bound symbols: t0 = ltorch.exp(t) and t_out = prims.copy_(t0, t). The functionalization removes the copy, replaces ltorch.exp_(t) with t0 = ltorch.exp(t), and replaces uses of t after exp_ with t0.
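
To make the mechanism concrete, here is a minimal sketch (not from the PR) of the user-facing behavior, using the thunder.jit and thunder.last_traces API that appears elsewhere in this thread; the trace lines in the comments are illustrative and the proxy names are assumed:

import torch
import thunder

def g(a: torch.Tensor) -> torch.Tensor:
    t = a * 2     # `t` is an intermediate of the trace
    t.exp_()      # in-place op on the intermediate
    return t + 1  # `t` is used after the in-place op

jg = thunder.jit(g)
out = jg(torch.randn(3))

# Conceptually, the initial trace contains something like
#   t_out = ltorch.exp_(t)  # subsymbols: t0 = ltorch.exp(t); t_out = prims.copy_(t0, t)
# and after functionalization the trailing copy is dropped and later uses of `t` read `t0`:
#   t0 = ltorch.exp(t)
#   t1 = ltorch.add(t0, 1)
print(thunder.last_traces(jg)[-1])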

The covered ops are ones that either (a) have an in-place variant, such as torch.exp and torch.add, or (b) take inplace as one of their args, such as torch.nn.functional.relu.

So this would not cover #145 in its entirety, nor take aliases into consideration.

thunder/core/proxies.py (outdated, resolved)
@t-vi
Collaborator

t-vi commented Jun 12, 2024

Awesome stuff. The fundamental questions I have about this are

@crcrpar
Collaborator Author

crcrpar commented Jun 12, 2024

  • Does this give silent errors? What do we do about this?

Could you give me some examples of silent errors? I'm not quite following what they would be.

"the transition" of what? The better coverage by supporting in-place ops?
Anyway, I'd like to give it a try as well. I'm just not quite clear about what would look better, even in the short term.

@t-vi
Collaborator

t-vi commented Jun 12, 2024

Could you give me some examples of silent errors? I'm not quite following what they would be.

Essentially things that rely on aliases (x = thunder.zeros(1, 1); x.diag().fill_(1.0))
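
For illustration, a hedged eager-PyTorch sketch (not from the thread) of the kind of alias-dependent in-place update meant here; a functionalization pass that does not track aliases could silently drop the write into x:

import torch

x = torch.zeros(2, 2)
v = x.diagonal()  # `v` is a view aliasing x's storage
v.fill_(1.0)      # in eager mode this also writes into `x` through the alias
print(x)          # tensor([[1., 0.],
                  #         [0., 1.]])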

"the transition" of what? The better coverage by supporting in-place ops?

So I'm imagining that people will use this as soon as we have it and won't want us to regress on support for their models.
Relative to #145, I wonder whether this is an alternative long-term solution, whether we can evolve it into a long-term solution, or how else it fits into a long-term plan.

@crcrpar crcrpar force-pushed the crpa/inplace-support branch from 0165e23 to fe589dc Compare June 13, 2024 12:43
@crcrpar crcrpar force-pushed the crpa/inplace-support branch 2 times, most recently from 6d9a996 to c735f1f Compare June 14, 2024 07:04
@mruberry
Collaborator

I think I'm a little confused. This allows for operations like add_ to appear in traces, and then a pass converts those operations to something like

c = add(a, b)
copy_(a, c)

?

What happens to the copy_ operations?

In general I think we need a design review of how to handle inplace operations before merging PRs related to them.

Collaborator

@jjsjann123 jjsjann123 left a comment

One thing we are missing here is that, in the functionalization pass, we should identify input sources.

i.e., if an in-place update is applied to a buffer that is not an intermediate created in the trace, we shouldn't functionalize it yet.

If you try to run batch_norm with your example, you'll notice that num_batches_tracked's update gets functionalized. That's a silently wrong result. We should instead throw a loud error for that case.

thunder/core/transform_common.py (outdated, resolved)
thunder/core/transform_common.py (outdated, resolved)
@jjsjann123
Collaborator

I think I'm a little confused. This allows for operations like add_ to appear in traces, and then a pass converts those operations to something like

c = add(a, b)
copy_(a, c)

?

What happens to the copy_ operations?

In general I think we need a design review of how to handle inplace operations before merging PRs related to them.

I think @crcrpar is only handling in-place on intermediates that can be functionalized away, which already helps (e.g., in-place activations in ResNet).

@jjsjann123
Collaborator

jjsjann123 commented Jun 14, 2024

The example I was referring to in my comment:

import thunder
import torch

mod = torch.nn.modules.BatchNorm2d(4)
# mod.track_running_stats = None
mod.cuda()

jfoo = thunder.jit(mod)

a = torch.randn(2, 4, 5, 5, device="cuda")

print(mod.running_mean)
print(mod.num_batches_tracked)
print(jfoo(a))
print(mod.running_mean)
print(mod.num_batches_tracked)  # should have been incremented to tensor(1)

traces = thunder.last_traces(jfoo)
print(f"===\n{traces=}")

I'm not asking you to necessarily support it in this PR, but we should error out instead of producing a silently wrong result.

@crcrpar
Collaborator Author

crcrpar commented Jun 15, 2024

If you try to run batch_norm with your example, you'll notice that num_batches_tracked's update gets functionalized. That's a silently wrong result. We should instead throw a loud error for that case.

8889120 addressed it.

@crcrpar crcrpar force-pushed the crpa/inplace-support branch from c17b4d8 to f6dc67b Compare June 15, 2024 07:04
@crcrpar
Collaborator Author

crcrpar commented Jun 16, 2024

I think I'm a little confused. This allows for operations like add_ to appear in traces, and then a pass converts those operations to something like

c = add(a, b)
copy_(a, c)

?

What happens to the copy_ operations?

ltorch.add_(a, b) is expressed with two subsymbols, t_something = add(a, b) and copy_(t_something, a); the latter is removed by functionalization.

@crcrpar crcrpar force-pushed the crpa/inplace-support branch from aeadbb5 to d9a1275 Compare June 16, 2024 16:24
@IvanYashchuk
Collaborator

What happens to the copy_ operations? (#584 (comment))

If it's in the user script, it's not traced today and it's not traced in this PR. The only way to introduce copy_ is from within Thunder.

@IvanYashchuk
Collaborator

According to #584 (comment), BatchNorm's num_batches_tracked is correctly updated in-place within an nvFuser region. I hope nvFuser's fusion performance is not suddenly destroyed by this. We've got one benchmark case to try it:

def test_batch_norm(benchmark, executor: Callable, compute_type: ComputeType):

Collaborator

@IvanYashchuk IvanYashchuk left a comment

This PR does two things at once to demonstrate the usability of in-place tracing and generating functional code:

  1. In-place operations like abs_, add_, etc. are added to thunder.torch so that Thunder's Python Interpreter can recognize corresponding PyTorch operations and put them into the initial Thunder trace.
  2. Transform in-place operations on intermediates into out-of-place variants. This is an operator-level transform; there's no interaction between ops. In-place on views is handled in a separate pull request (Partially support in-place ops and tensor aliases #597).

We need part 1; an alternative to part 2 could be freezing the order of in-place operations and letting the PyTorch executor execute them. There could be other good alternatives, and part 2 can easily be disabled in the future if needed.
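
As a hedged sketch of how part 2 could be disabled, using the skip_inplace_functionalization compile option added in this PR and assuming compile options are passed to thunder.jit as keyword arguments:

import torch
import thunder

def h(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    c = a + b
    c.exp_()  # in-place on an intermediate: this is what part 2 functionalizes
    return c

jh = thunder.jit(h)                                               # default: functionalize
jh_inplace = thunder.jit(h, skip_inplace_functionalization=True)  # opt out

x, y = torch.randn(3), torch.randn(3)
jh(x, y); jh_inplace(x, y)
print(thunder.last_traces(jh)[-1])          # exp_ replaced by its out-of-place counterpart
print(thunder.last_traces(jh_inplace)[-1])  # in-place op (with its trailing copy_) left in the trace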

Collaborator

All of the changes to this file look good. There's no other obvious way we can support reading in-place PyTorch operations from user code with Thunder's Python Interpreter.
Independent of how the in-place operations are treated later, we need to get them into the initial trace first.

Collaborator Author

Noticed that we might want to add some distributed calls, such as torch.distributed.all_reduce, torch.distributed.all_gather_into_tensor, and torch.distributed.reduce_scatter_tensor, later.

thunder/tests/opinfos.py (resolved)

num_orig_bsyms = len(trace.bound_symbols)

# note(crcrpar): The path looks neat but it does not work for a trace
Collaborator

Where is this called? Is it still needed now that functionalization is moved into thunder.jit to be applied after the interpreter?

Collaborator Author

No, great catch. Reverted the changes in this file.

thunder/core/jit_ext.py (outdated, resolved)
thunder/core/transform_common.py (outdated, resolved)
"""Functionalize in-place ops in ``computation_trace``.

In thunder, an in-place op is an out-of-place or functional op followed by :func:`~thunder.core.prims.copy_`.
This function replaces such in-place ops with out-of-place ops.
Collaborator

"... only if the in-place argument is intermediate to the trace", right?

Collaborator

Oh, I see later "functionalization is not applied, if any of an in-place op's arguments are computation_trace.args or computation_trace.kwargs."

Collaborator

Should we error / warn in that case?

Collaborator Author

It seems that BatchNorm's num_batches_tracked tensor update is expressed as ltorch.add_(num_batches_tracked, 1) and the tensor is an arg, so this makes sense to me. Also, if one or more of the args and kwargs are updated in-place, then I guess there's some intention behind it, so I'm not inclined to ban such cases.

thunder/core/transform_common.py (outdated, resolved)
@crcrpar crcrpar force-pushed the crpa/inplace-support branch from d9a1275 to 4238e1c Compare June 17, 2024 12:03
new_bsyms.append(new_bsym)
continue
functional_sym_name = new_bsym.sym.id.split(".")[-1][:-1]
check(
Collaborator

nitpick: should this be a check, or should we rather just skip?

Collaborator

If is_functionalizable checks for an existing mapping, we shouldn't need this check here.

swap_map[variableify(copy_return)] = copy_from
new_bsym.subsymbols = new_bsym.subsymbols[:-1]
new_bsym = new_bsym.from_bsym_swap_proxies(swap_map)
functional_sym: Symbol = getattr(thunder.torch, functional_sym_name)
Collaborator

The inconsistency here looks bad.

It's fine to choose to only functionalize torch.xxx_ to torch.xxx. But the is_functionalizable thing should have the same logic.


def is_functionalizable(bsym: BoundSymbol) -> bool:
    """Has `OpTags.IN_PLACE` and its args are NOT ``computation_trace.args`` nor ``computation_trace.kwargs``."""
    return (
Collaborator

@jjsjann123 jjsjann123 Jun 17, 2024

I don't think we care about having the IN_PLACE tag or not here, since the logic below for replacing doesn't take it into consideration.

I feel the logic here should just check for torch.xxx_ and see if there is a torch.xxx.

If we want to move forward with the IN_PLACE tag here, maybe we should maintain a map from the in-place to the out-of-place function, instead of relying on the trailing underscore.
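
A minimal sketch of the kind of explicit mapping suggested here, with illustrative string ids rather than thunder symbols (the PR later adds a map named _inplace_to_out_of_place, mentioned below):

# hypothetical in-place-id -> out-of-place-id table, instead of stripping the trailing "_"
_INPLACE_TO_OUT_OF_PLACE: dict[str, str] = {
    "torch.Tensor.add_": "torch.add",
    "torch.Tensor.exp_": "torch.exp",
    "torch.Tensor.abs_": "torch.abs",
}

def out_of_place_id(inplace_sym_id: str) -> str | None:
    # None means the op has no registered functional counterpart and should not be functionalized
    return _INPLACE_TO_OUT_OF_PLACE.get(inplace_sym_id)

print(out_of_place_id("torch.Tensor.exp_"))  # "torch.exp"
print(out_of_place_id("torch.Tensor.div_"))  # None (not registered)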

Collaborator Author

Now here's _inplace_to_out_of_place.

        bsym.sym.tags
        and prims.OpTags.IN_PLACE in bsym.sym.tags
        and bsym.subsymbols
        and bsym.subsymbols[-1].sym.id == prims.PrimIDs.COPY_
Collaborator

We certainly should drop the subsymbols check. This is irrelevant to how this PR is handling functionalization.

Collaborator Author

This is irrelevant to how this PR is handling functionalization.

Why is it irrelevant? Currently in-place bsyms have the out-of-place op and copy as their subsymbols, so I think it's fair to check that the last sub bound symbol is a copy.

Collaborator Author

How can we tell an appropriate output tensor proxy if a bsym doesn't have a copy_ as its last sub bsym, while avoiding introducing a lot of new tensor proxy names in a trace?

Collaborator

Oops, sorry, I read your implementation wrong earlier... I thought we were doing a blind torch.xxx_ to torch.xxx replacement, but that's not the case. You are actually only looking at the last subsymbol and replacing that one entry.

Collaborator

That feels a bit restricted... But a first step is still better than nothing, and I'll stop nitpicking on that.

Collaborator Author

That feels a bit restricted... But a first step is still better than nothing, and I'll stop nitpicking on that.

How would it be a bit restricted compared to a blind torch.foo_ to torch.foo replacement?

Collaborator

No, that's not what I meant.

By "that feels a bit restricted", I'm referring to the alternative where we functionalize directly at the subsymbol prims.copy_ level. But again, we don't have to do that in this PR.

@crcrpar
Collaborator Author

crcrpar commented Jun 17, 2024

I now think OpTags.IN_PLACE would not work quite well, given that there are some functions that take inplace as an argument, such as torch.relu, whose ltorch def is:

@torchsymbol(torch.relu, torch.nn.functional.relu, id="torch.relu", is_method=True)
def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike:
    utils.check(not inplace, lambda: f"relu only supports inplace=False", exception_type=NotImplementedError)
    return where(a > 0, a, 0)

for which I cannot pass this tag separately.

@jjsjann123
Collaborator

I now think OpTags.IN_PLACE would not work quite well, given that there are some functions that take inplace as an argument, such as torch.relu, whose ltorch def is:

@torchsymbol(torch.relu, torch.nn.functional.relu, id="torch.relu", is_method=True)
def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike:
    utils.check(not inplace, lambda: f"relu only supports inplace=False", exception_type=NotImplementedError)
    return where(a > 0, a, 0)

for which I cannot pass this tag separately.

The IN_PLACE tag is used more like a MAYBE_INPLACE in the implementation. Maybe switching to that?

@crcrpar crcrpar force-pushed the crpa/inplace-support branch from f12bf41 to d6fe101 Compare June 18, 2024 03:14
Collaborator

@jjsjann123 jjsjann123 left a comment

My earlier concern has been addressed. Stamping.

thunder/torch/__init__.py (resolved)
utils.check(
    dtypes.is_float_dtype(a.dtype),
    lambda: f"hardswish only supports floating point dtypes, got {a.dtype}",
    exception_type=ValueError,
)
return a * relu6(a + 3) / 6
out = a * relu6(a + 3) / 6
if inplace:
Collaborator

Note to myself: this needs to be made a static constraint.
Linking PR #613.

@crcrpar crcrpar force-pushed the crpa/inplace-support branch from 300256f to bd36039 Compare June 18, 2024 12:36
Comment on lines 17 to 23
def sample_generator_wrapper(sample_generator, is_silu: bool = False):

    def f(*args, **kwargs):
        for sample in sample_generator(*args, **kwargs):
            if not is_silu:
                yield SampleInput(*(list(sample.args) + [True]), **sample.kwargs)
            else:
                # silu treats `inplace` as a kwarg
                # ref: https://github.com/Lightning-AI/lightning-thunder/commit/335d84c89
                new_kwargs = {"inplace": True}
                if sample.kwargs:
                    new_kwargs.update(sample.kwargs)
                yield SampleInput(*sample.args, **new_kwargs)

    return f
Collaborator Author

rel: #615

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
@crcrpar crcrpar force-pushed the crpa/inplace-support branch from 4fb5b9a to d4402b0 Compare June 20, 2024 07:29
@crcrpar crcrpar marked this pull request as ready for review June 20, 2024 07:30
@crcrpar crcrpar requested review from lantiga, robieta and t-vi as code owners June 20, 2024 07:30
@nikitaved nikitaved self-requested a review June 20, 2024 11:30
Comment on lines +387 to +388
def is_functionalizable(bsym: BoundSymbol) -> bool:
    """Has `OpTags.IN_PLACE` and its args are NOT ``computation_trace.args`` nor ``computation_trace.kwargs``."""
Contributor

@nikitaved nikitaved Jun 20, 2024

nit: is IN_PLACE actually used here? EDIT: yes, implicitly through being added to the map.

Contributor

Is it also true that the trace args/kwargs are being checked implicitly somewhere outside?


Comment on lines +418 to +421
intermediate_trace = from_trace(computation_trace)
intermediate_trace.bound_symbols = bsyms[:]
intermediate_trace.set_provenance(TraceProvenance("Intermediate trace of `functionalize_inplace_ops`"))
del bsyms
Contributor

nit: can't we just do intermediate_trace.bound_symbols = bsyms?

Collaborator Author

Yes. I didn't want to reuse bsyms.

Comment on lines +439 to +448
copy_bsym = bsym.subsymbols[-1]
copy_return = copy_bsym.flat_proxy_outs[0]
copy_from = copy_bsym.flat_proxy_args[0]
copy_to = copy_bsym.flat_proxy_args[1]
if copy_to in trace_args_set:
    new_bsyms.append(new_bsym)
else:
    swap_map[variableify(copy_return)] = copy_from
    new_bsym.subsymbols = new_bsym.subsymbols[:-1]
    new_bsym = new_bsym.from_bsym_swap_proxies(swap_map)
Contributor

@nikitaved nikitaved Jun 20, 2024

This logic looks similar to what step 1 is doing. Couldn't they be merged? It seems like the whole thing could be done in a single pass?

Comment on lines +455 to +456
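# (suggested comment, not in the PR) if the op takes an optional `inplace` argument,
# force it to False so the swapped-in out-of-place call traces the functional path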
if optional_inplace_arg_index > -1:
    flat_args[optional_inplace_arg_index] = False
Contributor

This probably needs a comment.

    _call_ctx=new_bsym._call_ctx,
)
new_bsyms.append(new_functional_bsym)
bsym_inplace_to_functional[new_bsym] = new_functional_bsym
Contributor

bsym_inplace_to_functional is never read from?

Collaborator Author

You're right, but at one point I tried to register this as an attribute of the provenance at L473.

@@ -503,6 +509,9 @@ def get_computation_and_inputs(*args, **kwargs):

prologue_traces = [prologue_trc]
computation_traces = [computation_trc]
if not compile_options.get("skip_inplace_functionalization", False):
Collaborator

Longer term, I wonder if we should have a set of default transformations and this be one of them, but for now it is OK.

@@ -72,7 +72,7 @@ def resolve_method(id: Any, *args, **kwargs) -> None | Callable:
# ctx.get_method throws an AttributeError when the context does not have the requested attribute, except
# for the prims language context, which always throws a ValueError
method: Callable = ctx.get_method(id, *args, **kwargs)
except (AttributeError, ValueError) as e:
except (AttributeError, ValueError):
Collaborator

Great catch!

"""Functionalize in-place ops in ``computation_trace``.

In thunder, an in-place op is an out-of-place or functional op followed by :func:`~thunder.core.prims.copy_`.
This function replaces such in-place ops with out-of-place ops.
Collaborator

Should we error / warn in that case?

bsyms.append(new_bsym)

intermediate_trace = from_trace(computation_trace)
intermediate_trace.bound_symbols = bsyms[:]
Collaborator

Nit: do we need to copy if we del below?

@@ -72,7 +72,7 @@ def resolve_method(id: Any, *args, **kwargs) -> None | Callable:
# ctx.get_method throws an AttributeError when the context does not have the requested attribute, except
# for the prims language context, which always throws a ValueError
method: Callable = ctx.get_method(id, *args, **kwargs)
except (AttributeError, ValueError) as e:
except (AttributeError, ValueError):
Collaborator

Great catch

@@ -503,6 +509,9 @@ def get_computation_and_inputs(*args, **kwargs):

prologue_traces = [prologue_trc]
computation_traces = [computation_trc]
if not compile_options.get("skip_inplace_functionalization", False):
    computation_traces.extend(functionalize_inplace_ops(computation_trace=computation_trc))
    computation_trc = computation_traces[-1]
Collaborator

Longer term, I wonder if this should be a "default transform", but maybe it is important that this goes first and so it is tricky with the timing.

In thunder, an in-place op is an out-of-place or functional op followed by :func:`~thunder.core.prims.copy_`.
This function replaces such in-place ops with out-of-place ops.
Note that functionalization is not applied, if any of an in-place op's arguments are
``computation_trace.args`` or ``computation_trace.kwargs``.
Collaborator

I wonder what we should do in these cases, though: warn or error?

bsyms.append(new_bsym)

intermediate_trace = from_trace(computation_trace)
intermediate_trace.bound_symbols = bsyms[:]
Collaborator

Nit: I don't think we strictly need the copy here.

@@ -0,0 +1,124 @@
from __future__ import annotations
Collaborator

Great to have tests for it!

Collaborator

@t-vi t-vi left a comment

Supergood! We'll need to make things safer, but this is an awesome start.
Thank you @IvanYashchuk @jjsjann123 @nikitaved @mruberry for your reviews and comments.

@t-vi t-vi merged commit e28ea5e into main Jun 20, 2024
39 checks passed
@t-vi t-vi deleted the crpa/inplace-support branch June 20, 2024 13:26
Comment on lines +1252 to +1254
@torchsymbol(torch.Tensor.abs_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def abs_(a: NumberLike | TensorLike, /) -> Number | TensorLike:
    return prims.copy_(abs(a), a)
Contributor

@nikitaved nikitaved Jun 20, 2024

nit: maybe adding a decorator for an out-of-place op to register an in-place counterpart could be cleaner if realizable? There we could also populate the map if needed.
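
A hedged, standalone sketch of the suggested decorator (hypothetical names; plain torch.Tensor.copy_ stands in for thunder's torchsymbol/prims.copy_ machinery), deriving the in-place variant from the out-of-place definition and populating the map as a side effect:

from typing import Callable
import torch

_inplace_to_out_of_place: dict[str, Callable] = {}

def register_inplace_variant(out_of_place: Callable) -> Callable:
    """Attach a `<name>_` variant that computes out of place and copies the result back."""
    def inplace_variant(a: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        result = out_of_place(a, *args, **kwargs)
        return a.copy_(result)  # stand-in for prims.copy_(result, a)

    inplace_variant.__name__ = out_of_place.__name__ + "_"
    _inplace_to_out_of_place[inplace_variant.__name__] = out_of_place
    out_of_place.inplace_variant = inplace_variant  # expose the generated variant
    return out_of_place

@register_inplace_variant
def my_abs(a: torch.Tensor) -> torch.Tensor:
    return torch.abs(a)

t = torch.tensor([-1.0, 2.0])
my_abs.inplace_variant(t)  # behaves like t.abs_()
print(t)                   # tensor([1., 2.])
print(_inplace_to_out_of_place)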

Comment on lines +46 to +48
for op in opinfos:
    if not (op.op in _functional_to_inplace or op.op in _functional_to_functional_with_inplace_arg):
        continue
Contributor

nit: maybe in the future we would like a flag on out-of-place OpInfos to test their in-place variants?

Comment on lines +64 to +66
sample_input_generator=(
    op.sample_input_generator if op.name != "masked_fill" else inplace_masked_fill_sample_generator
),
Contributor

nit: seems this motivates the comment from above. We could set test_in_place=False for masked_fill and create a separate OpInfo for the in-place variant.

Comment on lines +93 to +95
@ops(_inplace_opinfos, supported_dtypes=(dtypes.float32,))
def test_functionalization(op: OpInfo, device: str, dtype: dtypes.dtype, executor, _):
    import thunder
Contributor

@nikitaved nikitaved Jun 20, 2024

Seems like we might be missing some test with a sequence of in-place ops? Just to be sure that the bsym replacement logic is sound across multiple symbols, not just a single one. Or is that not for this PR?

Collaborator Author

Seems like we might be missing some test with a sequence of in-place ops?

Yes, this PR doesn't have such tests.

Just to be sure that the bsym replacement logic is sound across multiple symbols, not just a single one. Or is that not for this PR?

Locally I've been using the following snippet, so I hope the functionalization works. I just didn't have a clear picture of how to design tests with a sequence of in-place ops.

import torch
import thunder


def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    a += b
    c = torch.exp(a)
    d = torch.tanh(b)

    c += d

    d.div_(a)

    e = c + d
    f = torch.nn.functional.relu(e, inplace=True)
    g = a + b
    return f, c, d, torch.relu_(g)


def main():
    a, b = [torch.randn((2, 2), device="cuda", requires_grad=False) for _ in range(2)]
    a_, b_ = a.clone().detach(), b.clone().detach()
    jit_f = thunder.jit(f, executors=[thunder.executors.get_torch_executor()])
    c, d, e, g = jit_f(a, b)
    c_, d_, e_, g_ = f(a_, b_)

    traces = thunder.last_traces(jit_f)
    print(traces[-1])

    torch.testing.assert_close(d, d_)
    torch.testing.assert_close(c, c_)
    torch.testing.assert_close(e, e_)
    torch.testing.assert_close(g, g_)


if __name__ == "__main__":
    main()

Contributor

@nikitaved nikitaved Jun 20, 2024

We could add some handpicked tests for sure. We can always throw in some line duplications for numerically more stable ops :)

@apaz-cli
Contributor

Glorious.
