Functionalize in-place ops #584
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,13 @@ | |
import thunder.core.prims as prims | ||
import thunder.core.dtypes as dtypes | ||
import thunder.core.devices as devices | ||
from thunder.core.transform_common import dce, EarlyTransform, AdditionalTransform, PostOptimizationTransform | ||
from thunder.core.transform_common import ( | ||
dce, | ||
EarlyTransform, | ||
AdditionalTransform, | ||
PostOptimizationTransform, | ||
functionalize_inplace_ops, | ||
) | ||
from thunder.common import ( | ||
CompileData, | ||
CompileStats, | ||
|
@@ -503,6 +509,9 @@ def get_computation_and_inputs(*args, **kwargs): | |
|
||
prologue_traces = [prologue_trc] | ||
computation_traces = [computation_trc] | ||
if not compile_options.get("skip_inplace_functionalization", False): | ||
computation_traces.extend(functionalize_inplace_ops(computation_trace=computation_trc)) | ||
computation_trc = computation_traces[-1] | ||
Review comment: Longer term, I wonder if this should be a "default transform", but maybe it is important that this goes first and so it is tricky with the timing. |
||
|
||
if epilogue_trc is not None: | ||
epilogue_traces = [epilogue_trc] | ||
|
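The hunk above appends the traces returned by the pass and then takes the last entry as the current computation trace. A minimal sketch of that pattern follows, using plain strings in place of traces. `run_optional_pass` and `fake_functionalize` are hypothetical stand-ins and not part of thunder; only the option name `skip_inplace_functionalization` comes from the diff.

```python
from typing import Callable


def run_optional_pass(
    traces: list[str],
    options: dict,
    pass_fn: Callable[[str], list[str]],
) -> str:
    """Extend `traces` with the output of `pass_fn` unless the option disables it.

    Returns the most recent trace, mirroring `computation_traces[-1]` in the diff.
    """
    if not options.get("skip_inplace_functionalization", False):
        traces.extend(pass_fn(traces[-1]))
    return traces[-1]


def fake_functionalize(trace: str) -> list[str]:
    # Hypothetical pass: returns [] when there is nothing to rewrite,
    # otherwise an intermediate trace followed by the final trace.
    if "add_" not in trace:
        return []
    return [trace + "  # intermediate", trace.replace("add_", "add")]


traces = ["t = add_(a, b)"]
latest = run_optional_pass(traces, {}, fake_functionalize)
```

Because the pass returns an empty list when no in-place op is found, extending the list is a no-op in that case and the last entry is still the original trace.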
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -72,7 +72,7 @@ def resolve_method(id: Any, *args, **kwargs) -> None | Callable: | |
# ctx.get_method throws an AttributeError when the context does not have the requested attribute, except | ||
# for the prims language context, which always throws a ValueError | ||
method: Callable = ctx.get_method(id, *args, **kwargs) | ||
except (AttributeError, ValueError) as e: | ||
except (AttributeError, ValueError): | ||
Review comment: Great catch! |
Review comment: Great catch |
||
return None | ||
return method | ||
|
||
|
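The code comment in this hunk says that a lookup failure surfaces as `AttributeError` for most contexts and as `ValueError` for the prims language context, and that both are mapped to `None`. A small sketch of that behaviour with a toy context class follows. `ToyContext` and `resolve` are illustrative names and do not reflect thunder's actual classes.

```python
from typing import Callable, Optional


class ToyContext:
    """Toy language context whose failed lookups raise a configurable error type."""

    def __init__(self, methods: dict, error_type: type = AttributeError):
        self._methods = methods
        self._error_type = error_type

    def get_method(self, name: str) -> Callable:
        try:
            return self._methods[name]
        except KeyError:
            raise self._error_type(name) from None


def resolve(ctx: ToyContext, name: str) -> Optional[Callable]:
    # The exception object is never used, so it is not bound with `as e`.
    try:
        return ctx.get_method(name)
    except (AttributeError, ValueError):
        return None


torch_like = ToyContext({"add": lambda a, b: a + b})
prims_like = ToyContext({}, error_type=ValueError)
```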
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,23 @@ | ||
from __future__ import annotations | ||
import time | ||
from typing import Any | ||
from typing import TYPE_CHECKING | ||
from abc import ABC, abstractmethod | ||
from collections.abc import Sequence | ||
from itertools import filterfalse | ||
from functools import partial | ||
|
||
import thunder.core.prims as prims | ||
from thunder.core.baseutils import BoundSymbolInterface | ||
from thunder.core.proxies import Proxy, variableify, Variable | ||
from thunder.core.pytree import tree_flatten, tree_map | ||
from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy | ||
from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten | ||
from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags | ||
from thunder.core.trace import from_trace, TraceProvenance, TraceCtx as Trace | ||
from thunder.core.utils import ProxyDict, producers, check | ||
|
||
if TYPE_CHECKING: | ||
from thunder.core.proxies import ProxyInterface | ||
from thunder.core.symbol import Symbol, VariableInterface | ||
|
||
|
||
# | ||
# Common optimization and transform passes | ||
|
@@ -363,3 +368,107 @@ class PostOptimizationTransform(Transform, ABC): | |
@abstractmethod | ||
def transform_trace(self, computation_trace: Trace, **kwargs): | ||
pass | ||
|
||
|
||
def functionalize_inplace_ops(computation_trace: Trace) -> list[Trace]: | ||
"""Functionalize in-place ops in ``computation_trace``. | ||
|
||
In thunder, an in-place op is an out-of-place or functional op followed by :func:`~thunder.core.prims.copy_`. | ||
This function replaces such in-place ops with out-of-place ops. | ||
Review comment: "... only if the in-place argument is intermediate to the trace", right? |
Review comment: Oh, I see later "functionalization is not applied, if any of an in-place op's arguments are |
Review comment: Should we error / warn in that case? |
Review comment: it seems that BatchNorm's |
||
Note that functionalization is not applied, if any of an in-place op's arguments are | ||
``computation_trace.args`` or ``computation_trace.kwargs``. | ||
Review comment: I wonder what we should do in these cases, though, warn, error? |
||
|
||
For example, :func:`thunder.torch.add_` is represented as a :class:`thunder.core.symbol.BoundSymbol` | ||
whose `subsymbols` are :func:`thunder.torch.add` and :func:`thunder.core.prims.copy_`. This function | ||
replaces it with a :class:`~thunder.core.symbol.BoundSymbol` of :func:`~thunder.torch.add`. | ||
""" | ||
import thunder.torch | ||
|
||
def is_functionalizable(bsym: BoundSymbol) -> bool: | ||
"""Has `OpTags.IN_PLACE` and its args are NOT ``computation_trace.args`` nor ``computation_trace.kwargs``.""" | ||
Comment on lines +387 to +388
Review comment: nit: is |
Review comment: Is it also true that the trace args/kwargs are also being checked implicitly somewhere outside? |
||
return ( | ||
Review comment: If we want to move forward with the |
Review comment: now here's |
||
bsym.sym in thunder.torch._inplace_to_out_of_place | ||
and bsym.subsymbols | ||
and bsym.subsymbols[-1].sym.id == prims.PrimIDs.COPY_ | ||
Review comment: we certainly should drop the subsymbols check. This is irrelevant to how this PR is handling functionalization. |
Review comment: why is it irrelevant? Currently in-place bsyms have out-of-place and copy as their subsymbols, so I think it is fair to check that the last sub bound symbol is copy. |
Review comment: how can we tell an appropriate output tensor proxy if a bsym doesn't have a copy_ as its last sub bsym, while avoiding having a lot of new tensor proxy names in a trace? |
Review comment: oops. sorry I read your implementation wrong earlier... I thought we are doing a blind |
Review comment: That feels a bit restricted... But a first step is still better than nothing and I'll stop nitpicking on that. |
Review comment: how would it be a bit restricted compared to a blind |
Review comment: no that's not what I meant. By |
||
) | ||
|
||
if not any(is_functionalizable(bsym) for bsym in computation_trace.bound_symbols): | ||
nikitaved marked this conversation as resolved.
|
||
return [] | ||
|
||
# Step 1: where possible, return the tensors returned from `prims.copy_` rather than the args, for clarity. | ||
bsym: BoundSymbol | ||
swap_map: dict[VariableInterface, ProxyInterface] = {} | ||
bsyms: list[BoundSymbol] = [] | ||
for bsym in computation_trace.bound_symbols: | ||
new_bsym = bsym.from_bsym_swap_proxies(swap_map) | ||
|
||
# in-place functionalizable ops have `prims.copy_` as the last subsymbol. | ||
if not is_functionalizable(new_bsym): | ||
bsyms.append(new_bsym) | ||
continue | ||
|
||
copy_bsym = bsym.subsymbols[-1] | ||
copy_out = copy_bsym.flat_proxy_outs[0] | ||
copy_dst = copy_bsym.flat_proxy_args[1] | ||
swap_map[variableify(copy_dst)] = copy_out | ||
# make sure an in-place bsym returns `prims.copy_` output | ||
new_bsym = new_bsym.from_bsym_swap_proxies(swap_map, skip_inputs=True, skip_subsymbols=True) | ||
bsyms.append(new_bsym) | ||
|
||
intermediate_trace = from_trace(computation_trace) | ||
intermediate_trace.bound_symbols = bsyms[:] | ||
Review comment: Nit: do we need to copy if we del below? |
Review comment: Nit: I don't think we strictly need the copy here. |
||
intermediate_trace.set_provenance(TraceProvenance("Intermediate trace of `functionalize_inplace_ops`")) | ||
del bsyms | ||
Comment on lines +418 to +421
Review comment: nit: can't we just do |
Review comment: yes. I didn't want to reuse |
||
|
||
# Step 2: Remove `prims.copy_` if it's the last one of `bsym.subsymbols`, | ||
# unless `copy_to` is `computation_trace.args` or `computation_trace.kwargs` | ||
trace_args_set = ProxyDict() | ||
for a in filter( | ||
lambda a: isinstance(a, TensorProxy), tree_flatten((computation_trace.args, computation_trace.kwargs))[0] | ||
): | ||
trace_args_set[a] = a | ||
bsym_inplace_to_functional = {} | ||
swap_map.clear() | ||
new_bsyms: list[BoundSymbol] = [] | ||
for bsym in intermediate_trace.bound_symbols: | ||
new_bsym = bsym.from_bsym_swap_proxies(swap_map) | ||
|
||
if not is_functionalizable(new_bsym): | ||
new_bsyms.append(new_bsym) | ||
continue | ||
copy_bsym = bsym.subsymbols[-1] | ||
copy_return = copy_bsym.flat_proxy_outs[0] | ||
copy_from = copy_bsym.flat_proxy_args[0] | ||
copy_to = copy_bsym.flat_proxy_args[1] | ||
if copy_to in trace_args_set: | ||
new_bsyms.append(new_bsym) | ||
else: | ||
swap_map[variableify(copy_return)] = copy_from | ||
new_bsym.subsymbols = new_bsym.subsymbols[:-1] | ||
new_bsym = new_bsym.from_bsym_swap_proxies(swap_map) | ||
Comment on lines +439 to +448
Review comment: This logic looks similar to what step 1 is doing. Couldn't they be merged? It seems like the whole thing could be done in a single pass? |
||
|
||
functional_sym: Symbol | ||
optional_inplace_arg_index: int | ||
functional_sym, optional_inplace_arg_index = thunder.torch._inplace_to_out_of_place[new_bsym.sym] | ||
|
||
flat_args, flat_args_spec = tree_flatten((new_bsym.args, new_bsym.kwargs)) | ||
if optional_inplace_arg_index > -1: | ||
flat_args[optional_inplace_arg_index] = False | ||
Comment on lines +455 to +456
Review comment: This probably needs a comment. |
||
args, kwargs = tree_unflatten(flat_args, flat_args_spec) | ||
new_functional_bsym = functional_sym.bind( | ||
*args, | ||
**kwargs, | ||
output=new_bsym.output, | ||
subsymbols=new_bsym.subsymbols, | ||
_call_ctx=new_bsym._call_ctx, | ||
) | ||
new_bsyms.append(new_functional_bsym) | ||
bsym_inplace_to_functional[new_bsym] = new_functional_bsym | ||
Review comment: you're right but I once tried to register this as an attribute of provenance at L473 |
||
|
||
functionalized_computation_trace = from_trace(computation_trace) | ||
functionalized_computation_trace.bound_symbols = new_bsyms | ||
functionalized_computation_trace.set_provenance(TraceProvenance("Functionalize in-place ops")) | ||
# note(crcrpar): I kind of want to do the following two. | ||
# functionalized_computation_trace._provenance.swap_map = swap_map | ||
# functionalized_computation_trace._provenance.bsym_inplace_to_functional = bsym_inplace_to_functional | ||
return [intermediate_trace, functionalized_computation_trace] |
Review comment: Longer term, I wonder if we should have a set of default transformations and this be one of them, but for now it is OK.
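For readers who want the idea behind `functionalize_inplace_ops` without thunder's data structures, here is a toy model. It is a sketch under assumptions, not the PR's implementation: `Op`, `make_inplace`, `functionalize`, and `INPLACE_TO_FUNCTIONAL` are hypothetical names, tensors are plain strings, and the two steps of the PR are collapsed into a single loop for brevity. It keeps the two rules stated in the docstring: an in-place op is a functional op followed by `copy_`, and ops whose destination is a trace argument are left untouched.

```python
from dataclasses import dataclass, field

# Hypothetical mapping, loosely analogous to `_inplace_to_out_of_place` in the diff.
INPLACE_TO_FUNCTIONAL = {"add_": "add", "mul_": "mul"}


@dataclass
class Op:
    name: str
    args: tuple
    out: str
    subops: list = field(default_factory=list)


def make_inplace(name: str, dst: str, other: str, tmp: str, out: str) -> Op:
    """Build an in-place op as `functional(dst, other) -> tmp` then `copy_(tmp, dst) -> out`."""
    functional = INPLACE_TO_FUNCTIONAL[name]
    return Op(
        name,
        (dst, other),
        out,
        [Op(functional, (dst, other), tmp), Op("copy_", (tmp, dst), out)],
    )


def functionalize(ops: list, trace_args: set) -> list:
    swap: dict = {}  # old tensor name -> name that later ops should read instead
    result = []
    for op in ops:
        args = tuple(swap.get(a, a) for a in op.args)
        functionalizable = (
            op.name in INPLACE_TO_FUNCTIONAL
            and bool(op.subops)
            and op.subops[-1].name == "copy_"
        )
        if not functionalizable:
            result.append(Op(op.name, args, op.out, op.subops))
            continue
        copy_from, copy_to = op.subops[-1].args
        if copy_to in trace_args:
            # Writing into a trace argument is observable by the caller: keep the op.
            result.append(Op(op.name, args, op.out, op.subops))
            continue
        # Later readers of the destination or of the copy output use the functional result.
        swap[copy_to] = copy_from
        swap[op.out] = copy_from
        result.append(Op(INPLACE_TO_FUNCTIONAL[op.name], args, copy_from))
    return result


ops = [
    Op("exp", ("a",), "b"),
    make_inplace("add_", "b", "c", "t0", "b1"),
    Op("sin", ("b",), "d"),
]
functional_ops = functionalize(ops, trace_args={"a", "c"})
```

In this toy trace, `b` is an intermediate, so `add_` becomes `add` and the later `sin` reads `t0` instead of `b`. An `add_` applied directly to the argument `a` would be kept as is.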