Functionalize in-place ops #584
Conversation
Awesome stuff. The fundamental questions I have about this are
Could you give me some examples of silent errors? I'm not quite following what they would be.
"The transition" of what? The better coverage by supporting in-place ops?
Essentially things that rely on aliases (
So I'm imagining that people will use this as soon as we have it and won't want us to regress on the support for their model.
Force-pushed from 0165e23 to fe589dc.
Force-pushed from 6d9a996 to c735f1f.
I think I'm a little confused. This allows for operations like
? What happens to the
In general, I think we need a design review of how to handle in-place operations before merging PRs related to them.
One thing we are missing here is that, in the functionalization pass, we should identify input sources.
I.e., if an in-place update is applied to a non-intermediate buffer (one not created in the trace), we shouldn't functionalize it yet.
If you try to run a batch_norm with your example, you'll notice that num_batches_tracked's update gets functionalized. That's a silently wrong result. We should instead throw a loud error for that case.
I think @crcrpar is only handling in-place on intermediates that can be functionalized away, which already helps (e.g., in-place activations in ResNet).
The example I was referring to in my comment:

import thunder
import torch

val = 5

def foo(flag):
    return flag

mod = torch.nn.modules.BatchNorm2d(4)
# mod.track_running_stats = None
mod.cuda()
jfoo = thunder.jit(mod)
a = torch.randn(2, 4, 5, 5, device="cuda")
print(mod.running_mean)
print(mod.num_batches_tracked)
print(jfoo(a))  # tensor([1.])
print(mod.running_mean)
print(mod.num_batches_tracked)
orig_trace = thunder.last_traces(jfoo)[0]
traces = thunder.last_traces(jfoo)
print(f"===\n{traces=}")

I'm not asking you to necessarily support it in this PR, but we should error out instead of producing a silently wrong result.
8889120 accessed it
Force-pushed from c17b4d8 to f6dc67b.
Force-pushed from aeadbb5 to d9a1275.
If it's in the user script, it's not traced today and it's not traced in this PR. The only way to introduce
According to #584 (comment)
This PR does two things at once to demonstrate the usability of in-place tracing and generating functional code:
- In-place operations like abs_, add_, etc. are added to thunder.torch so that Thunder's Python Interpreter can recognize the corresponding PyTorch operations and put them into the initial Thunder trace.
- In-place operations on intermediates are transformed into their out-of-place variants. This is an operator-level transform; there's no interaction between ops. In-place on views is handled in a separate pull request (Partially support in-place ops and tensor aliases #597).
We need part 1. An alternative to part 2 could be freezing the order of in-place operations and letting the PyTorch executor execute them. There could be other good alternatives, and part 2 can easily be disabled if needed in the future.
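As a minimal sketch of what the operator-level transform (part 2) does, here is an illustration in plain PyTorch rather than actual thunder traces; `f_inplace` and `f_functional` are hypothetical names for the shape of the code before and after functionalization:

```python
import torch

def f_inplace(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    t = a + b  # intermediate created inside the traced function
    t.exp_()   # in-place op on the intermediate
    return t * 2

def f_functional(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    t = a + b
    t0 = torch.exp(t)  # out-of-place variant; the trailing copy_ back into t is dropped
    return t0 * 2      # later uses of t are redirected to t0

a, b = torch.randn(3), torch.randn(3)
torch.testing.assert_close(f_inplace(a.clone(), b), f_functional(a, b))
```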
All of the changes to this file look good. There's no other obvious way we can support reading in-place PyTorch operations from user code with Thunder's Python Interpreter.
Independent of how the in-place operations are treated later, we need to get them into the initial trace first.
Noticed that we might want to add some distributed calls such as torch.distributed.all_reduce, torch.distributed.all_gather_into_tensor, and torch.distributed.reduce_scatter_tensor later.
thunder/core/jit_ext.py (outdated)
num_orig_bsyms = len(trace.bound_symbols)
# note(crcrpar): The path looks neat but it does not work for a trace
Where is this called? Is it still needed now that functionalization is moved into thunder.jit to be applied after the interpreter?
no, great catch. reverted the changes in this file
"""Functionalize in-place ops in ``computation_trace``. | ||
|
||
In thunder, an in-place is an out-of-place or functional op followed by :func:`~thunder.core.prims.copy_`. | ||
This function replaces such in-place ops with out-of-place ops. |
"... only if the in-place argument is intermediate to the trace", right?
Oh, I see later: "functionalization is not applied, if any of an in-place op's arguments are computation_trace.args or computation_trace.kwargs."
Should we error / warn in that case?
It seems that BatchNorm's num_batches_tracked tensor update is expressed as ltorch.add_(num_batches_tracked, 1), and the tensor is an arg, so this makes sense to me. Also, if one or more of the args & kwargs are updated in an in-place manner, then I guess there's some intention behind it, so I'm not inclined to ban such cases.
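To spell out why in-place updates on trace args/kwargs are left alone rather than functionalized, here is a minimal sketch in plain PyTorch (the function names are illustrative, not thunder code): rewriting the in-place update on a caller-owned buffer into an out-of-place op would silently drop the side effect the caller relies on.

```python
import torch

def step_inplace(num_batches_tracked: torch.Tensor) -> None:
    num_batches_tracked.add_(1)      # side effect visible to the caller

def step_functionalized(num_batches_tracked: torch.Tensor) -> torch.Tensor:
    return num_batches_tracked + 1   # caller's buffer is left untouched

buf = torch.zeros((), dtype=torch.long)
step_inplace(buf)
assert buf.item() == 1   # the in-place version updates the buffer
step_functionalized(buf)
assert buf.item() == 1   # the functionalized version would silently skip the update
```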
Force-pushed from d9a1275 to 4238e1c.
thunder/core/transform_common.py (outdated)
new_bsyms.append(new_bsym)
continue
functional_sym_name = new_bsym.sym.id.split(".")[-1][:-1]
check(
nitpick: should this be a check, or should we rather just skip?
If is_functionalizable checks for an existing mapping, we shouldn't need this check here.
thunder/core/transform_common.py (outdated)
swap_map[variableify(copy_return)] = copy_from
new_bsym.subsymbols = new_bsym.subsymbols[:-1]
new_bsym = new_bsym.from_bsym_swap_proxies(swap_map)
functional_sym: Symbol = getattr(thunder.torch, functional_sym_name)
The inconsistency here looks bad.
It's fine to choose to only functionalize torch.xxx_ to torch.xxx. But the is_functionalizable thing should have the same logic.
def is_functionalizable(bsym: BoundSymbol) -> bool:
    """Has `OpTags.IN_PLACE` and its args are NOT ``computation_trace.args`` nor ``computation_trace.kwargs``."""
    return (
I don't think we care about having the IN_PLACE tag or not here, since the replacement logic below doesn't take it into consideration.
I feel the logic here should just check for torch.xxx_ and see if there is a torch.xxx.
If we want to move forward with the in_place tag here, maybe we should maintain a map from in_place to out_of_place functions, instead of relying on the trailing underscore.
now here's _inplace_to_out_of_place
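For illustration only, such a map might look like the following; this is a hypothetical sketch keyed on plain PyTorch callables (names and contents assumed, not thunder's actual `_inplace_to_out_of_place` table), so the lookup doesn't rely on a trailing underscore:

```python
import torch

# Hypothetical mapping from in-place ops to their out-of-place counterparts.
_inplace_to_out_of_place_sketch = {
    torch.Tensor.abs_: torch.abs,
    torch.Tensor.add_: torch.add,
    torch.Tensor.exp_: torch.exp,
}

def out_of_place_for(inplace_op):
    """Return the functional counterpart, or None if we don't know one."""
    return _inplace_to_out_of_place_sketch.get(inplace_op)

assert out_of_place_for(torch.Tensor.exp_) is torch.exp
assert out_of_place_for(torch.Tensor.copy_) is None
```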
bsym.sym.tags
and prims.OpTags.IN_PLACE in bsym.sym.tags
and bsym.subsymbols
and bsym.subsymbols[-1].sym.id == prims.PrimIDs.COPY_
We certainly should drop the subsymbols check. This is irrelevant to how this PR is handling functionalization.
This is irrelevant to how this PR is handling functionalization.
Why is it irrelevant? Currently, in-place bsyms have the out-of-place op and copy as their subsymbols, so I think it's fair to check that the last sub bound symbol is a copy.
How can we tell the appropriate output tensor proxy if a bsym doesn't have a copy_ as its last sub-bsym, while avoiding a lot of new tensor proxy names in a trace?
Oops, sorry, I read your implementation wrong earlier... I thought we were doing a blind torch.xxx_ to torch.xxx replacement, but that's not the case. You are actually only looking at the last subsymbol and replacing that one entry.
That feels a bit restricted... But a first step is still better than nothing, and I'll stop nitpicking on that.
That feels a bit restricted... But a first step is still better than nothing, and I'll stop nitpicking on that.
How would it be a bit restricted compared to a blind torch.foo_ to torch.foo replacement?
No, that's not what I meant.
By "That feels a bit restricted", I'm referring to the alternative that we functionalize directly at the subsymbol prims.copy_ level. But again, we don't have to do that in this PR.
I now think lightning-thunder/thunder/torch/__init__.py lines 1458 to 1462 in 8309fc0
The
Force-pushed from f12bf41 to d6fe101.
My earlier concern has been addressed. Stamping.
utils.check(
    dtypes.is_float_dtype(a.dtype),
    lambda: f"hardswish only supports floating point dtypes, got {a.dtype}",
    exception_type=ValueError,
)
return a * relu6(a + 3) / 6
out = a * relu6(a + 3) / 6
if inplace:
Note to myself: this needs to be made a static constraint.
Linking PR #613.
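For context, the idiom the diff above follows can be sketched in plain PyTorch as below; the body of the `if inplace:` branch is an assumption here (in thunder it would be expressed via `prims.copy_(out, a)`), shown only to illustrate the pattern for ops that take an `inplace` flag.

```python
import torch

def hardswish_like(a: torch.Tensor, inplace: bool = False) -> torch.Tensor:
    out = a * torch.nn.functional.relu6(a + 3) / 6
    if inplace:
        # assumed branch body: write the result back into `a`
        # (thunder would express this as prims.copy_(out, a))
        return a.copy_(out)
    return out

x = torch.randn(4)
y = hardswish_like(x.clone(), inplace=True)
torch.testing.assert_close(y, torch.nn.functional.hardswish(x))
```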
Force-pushed from 300256f to bd36039.
def sample_generator_wrapper(sample_generator, is_silu: bool = False):

    def f(*args, **kwargs):
        for sample in sample_generator(*args, **kwargs):
            if not is_silu:
                yield SampleInput(*(list(sample.args) + [True]), **sample.kwargs)
            else:
                # silu treats `inplace` as a kwarg
                # ref: https://github.com/Lightning-AI/lightning-thunder/commit/335d84c89
                new_kwargs = {"inplace": True}
                if sample.kwargs:
                    new_kwargs.update(sample.kwargs)
                yield SampleInput(*sample.args, **new_kwargs)

    return f
rel: #615
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Force-pushed from 4fb5b9a to d4402b0.
def is_functionalizable(bsym: BoundSymbol) -> bool:
    """Has `OpTags.IN_PLACE` and its args are NOT ``computation_trace.args`` nor ``computation_trace.kwargs``."""
Nit: is IN_PLACE actually used here? EDIT: yes, implicitly, through being added to the map.
Is it also true that the trace args/kwargs are being checked implicitly somewhere outside?
intermediate_trace = from_trace(computation_trace)
intermediate_trace.bound_symbols = bsyms[:]
intermediate_trace.set_provenance(TraceProvenance("Intermediate trace of `functionalize_inplace_ops`"))
del bsyms
Nit: can't we just do intermediate_trace.bound_symbols = bsyms?
yes. I didn't want to reuse bsyms
copy_bsym = bsym.subsymbols[-1]
copy_return = copy_bsym.flat_proxy_outs[0]
copy_from = copy_bsym.flat_proxy_args[0]
copy_to = copy_bsym.flat_proxy_args[1]
if copy_to in trace_args_set:
    new_bsyms.append(new_bsym)
else:
    swap_map[variableify(copy_return)] = copy_from
    new_bsym.subsymbols = new_bsym.subsymbols[:-1]
    new_bsym = new_bsym.from_bsym_swap_proxies(swap_map)
This logic looks similar to what step 1 is doing. Couldn't they be merged? It seems like the whole thing could be done in a single pass?
if optional_inplace_arg_index > -1:
    # presumably forces the functional path by overriding the op's `inplace` flag to False
    flat_args[optional_inplace_arg_index] = False
This probably needs a comment.
    _call_ctx=new_bsym._call_ctx,
)
new_bsyms.append(new_functional_bsym)
bsym_inplace_to_functional[new_bsym] = new_functional_bsym
bsym_inplace_to_functional is never read from?
You're right, but I once tried to register this as an attribute of provenance at L473.
@@ -503,6 +509,9 @@ def get_computation_and_inputs(*args, **kwargs):
prologue_traces = [prologue_trc]
computation_traces = [computation_trc]
if not compile_options.get("skip_inplace_functionalization", False):
Longer term, I wonder if we should have a set of default transformations and this be one of them, but for now it is OK.
@@ -72,7 +72,7 @@ def resolve_method(id: Any, *args, **kwargs) -> None | Callable:
# ctx.get_method throws an AttributeError when the context does not have the requested attribute, except
# for the prims language context, which always throws a ValueError
method: Callable = ctx.get_method(id, *args, **kwargs)
except (AttributeError, ValueError) as e:
except (AttributeError, ValueError):
Great catch!
"""Functionalize in-place ops in ``computation_trace``. | ||
|
||
In thunder, an in-place is an out-of-place or functional op followed by :func:`~thunder.core.prims.copy_`. | ||
This function replaces such in-place ops with out-of-place ops. |
Should we error / warn in that case?
bsyms.append(new_bsym)

intermediate_trace = from_trace(computation_trace)
intermediate_trace.bound_symbols = bsyms[:]
Nit: do we need to copy if we del below?
@@ -72,7 +72,7 @@ def resolve_method(id: Any, *args, **kwargs) -> None | Callable:
# ctx.get_method throws an AttributeError when the context does not have the requested attribute, except
# for the prims language context, which always throws a ValueError
method: Callable = ctx.get_method(id, *args, **kwargs)
except (AttributeError, ValueError) as e:
except (AttributeError, ValueError):
Great catch
@@ -503,6 +509,9 @@ def get_computation_and_inputs(*args, **kwargs):
prologue_traces = [prologue_trc]
computation_traces = [computation_trc]
if not compile_options.get("skip_inplace_functionalization", False):
    computation_traces.extend(functionalize_inplace_ops(computation_trace=computation_trc))
computation_trc = computation_traces[-1]
Longer term, I wonder if this should be a "default transform", but maybe it is important that this goes first and so it is tricky with the timing.
In thunder, an in-place op is an out-of-place or functional op followed by :func:`~thunder.core.prims.copy_`.
This function replaces such in-place ops with out-of-place ops.
Note that functionalization is not applied if any of an in-place op's arguments are
``computation_trace.args`` or ``computation_trace.kwargs``.
I wonder what we should do in these cases, though, warn, error?
bsyms.append(new_bsym)

intermediate_trace = from_trace(computation_trace)
intermediate_trace.bound_symbols = bsyms[:]
Nit: I don't think we strictly need the copy here.
@@ -0,0 +1,124 @@
from __future__ import annotations
Great to have tests for it!
Supergood! We'll need to make things safer, but this is an awesome start.
Thank you @IvanYashchuk @jjsjann123 @nikitaved @mruberry for your reviews and comments.
@torchsymbol(torch.Tensor.abs_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def abs_(a: NumberLike | TensorLike, /) -> Number | TensorLike:
    return prims.copy_(abs(a), a)
nit: maybe adding a decorator for an out-of-place op to register an in-place counterpart could be cleaner if realizable? There we could also populate the map if needed.
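If that route were taken, such a decorator might look roughly like the sketch below; this is hypothetical (the decorator name, the registries, and the use of Tensor.copy_ in place of prims.copy_ are all assumptions, not thunder's API), just to show how registering the in-place counterpart and populating the map could happen in one place.

```python
from typing import Callable

import torch

# Hypothetical registries; not part of thunder.
_inplace_registry: dict[str, Callable] = {}
_inplace_to_out_of_place: dict[str, Callable] = {}

def with_inplace(inplace_name: str):
    """Register an in-place counterpart for the decorated out-of-place op."""
    def decorator(out_of_place: Callable) -> Callable:
        def inplace_variant(a: torch.Tensor, *args, **kwargs) -> torch.Tensor:
            # conceptually prims.copy_(out_of_place(a, ...), a); emulated with Tensor.copy_
            return a.copy_(out_of_place(a, *args, **kwargs))
        _inplace_registry[inplace_name] = inplace_variant
        _inplace_to_out_of_place[inplace_name] = out_of_place
        return out_of_place
    return decorator

@with_inplace("abs_")
def my_abs(a: torch.Tensor) -> torch.Tensor:
    return torch.abs(a)

t = torch.tensor([-1.0, 2.0])
_inplace_registry["abs_"](t)
assert torch.equal(t, torch.tensor([1.0, 2.0]))
```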
for op in opinfos:
    if not (op.op in _functional_to_inplace or op.op in _functional_to_functional_with_inplace_arg):
        continue
Nit: maybe in some future we would like to have a flag on out-of-place OpInfos to test their in-place variants?
sample_input_generator=(
    op.sample_input_generator if op.name != "masked_fill" else inplace_masked_fill_sample_generator
),
Nit: seems this motivates the comment from above. We could set test_in_place=False for masked_fill and create a separate OpInfo for the in-place variant.
@ops(_inplace_opinfos, supported_dtypes=(dtypes.float32,))
def test_functionalization(op: OpInfo, device: str, dtype: dtypes.dtype, executor, _):
    import thunder
Seems like we might be missing some test with a sequence of in-place ops? Just to be sure that the bsym replacement logic is sound across multiple symbols, not just a single one. Or is it not for this PR?
Seems like we might be missing some test with a sequence of in-place ops?
Yes, this PR doesn't have such tests.
Just to be sure that the bsym replacement logic is sound across multiple symbols, not just a single one. Or is it not for this PR?
Locally I've been using the following snippet, so I hope the functionalization works. I just didn't have a clear picture of how to design tests with a sequence of in-place ops.
import torch
import thunder

def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    a += b
    c = torch.exp(a)
    d = torch.tanh(b)
    c += d
    d.div_(a)
    e = c + d
    f = torch.nn.functional.relu(e, inplace=True)
    g = a + b
    return f, c, d, torch.relu_(g)

def main():
    a, b = [torch.randn((2, 2), device="cuda", requires_grad=False) for _ in range(2)]
    a_, b_ = a.clone().detach(), b.clone().detach()
    jit_f = thunder.jit(f, executors=[thunder.executors.get_torch_executor()])
    c, d, e, g = jit_f(a, b)
    c_, d_, e_, g_ = f(a_, b_)
    traces = thunder.last_traces(jit_f)
    print(traces[-1])
    torch.testing.assert_close(d, d_)
    torch.testing.assert_close(c, c_)
    torch.testing.assert_close(e, e_)
    torch.testing.assert_close(g, g_)

if __name__ == "__main__":
    main()
We could add some handpicked tests for sure. We can always throw some line duplications for numerically more stable ops :)
Glorious.
What does this PR do?
Express in-place ops using their out-of-place counterparts followed by prims.copy_, e.g. a.add_(b) as prims.copy_(prims.add(a, b), a), then let the added transform functionalize the trace by removing the trailing copy and updating the signature.
Let's say we have t.exp_() in a script and t is used afterwards. Thunder translates it into t_out = ltorch.exp_(t). This bound symbol has two sub bound symbols: t0 = ltorch.exp(t) and t_out = prims.copy_(t0, t). The functionalization removes the copy, replaces ltorch.exp_(t) with t0 = ltorch.exp(t), and replaces uses of t after exp_ with t0.
The covered ops are ones that either (a) have an in-place variant, such as torch.exp and torch.add, or (b) have inplace as one of their args, such as torch.nn.functional.relu.
So this would not cover the entire #145 broadly, nor take aliases into consideration.
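As a hedged end-to-end illustration of the behavior described above, the following sketch uses only APIs that appear in this thread (thunder.jit and thunder.last_traces); the exact contents of the printed trace may differ from what is claimed in the comment.

```python
import torch
import thunder

def g(t: torch.Tensor) -> torch.Tensor:
    t0 = t * 2   # t0 is an intermediate, so its in-place update can be functionalized
    t0.exp_()
    return t0 + 1

jg = thunder.jit(g)
x = torch.randn(4)
torch.testing.assert_close(jg(x), (x * 2).exp() + 1)
# The final trace should contain the out-of-place exp rather than exp_ followed by copy_.
print(thunder.last_traces(jg)[-1])
```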