support no_grad in thunder.jit #1423

Merged: 13 commits, Nov 18, 2024
4 changes: 4 additions & 0 deletions thunder/__init__.py
@@ -35,6 +35,7 @@
wrap_return_value_together_with_argments,
unwrap_return_value,
remove_context_manager_prims_from_trace,
tag_no_grad_symbols_pass,
)
from thunder.core.functionalization import (
check_inplace_to_views,
@@ -538,6 +539,9 @@ def get_computation_and_inputs(*args, **kwargs):
computation_trc = wrap_return_value_together_with_argments(computation_trc)
computation_traces.append(computation_trc)

computation_trc = tag_no_grad_symbols_pass(computation_trc)
computation_traces.append(computation_trc)

computation_trc = remove_context_manager_prims_from_trace(computation_trc)
computation_traces.append(computation_trc)
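For context, a minimal sketch (not part of this diff; the function `f` and the shapes are illustrative) of the user-facing behavior this pass enables by running before the context-manager prims are removed: operations inside a `torch.no_grad()` block of a `thunder.jit`-compiled function should contribute no gradient, matching eager PyTorch.

```python
import torch
import thunder

def f(x):
    y = x * 2.0
    with torch.no_grad():
        z = y.sin()  # tagged by the new pass and treated as constant for VJP
    return y + z

jf = thunder.jit(f)

x = torch.randn(4, requires_grad=True)
x_ref = x.detach().clone().requires_grad_(True)

out = jf(x)
out_ref = f(x_ref)
torch.testing.assert_close(out, out_ref)

out.sum().backward()
out_ref.sum().backward()
# The gradient flows only through the differentiable path (y), as in eager mode.
torch.testing.assert_close(x.grad, x_ref.grad)
```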

51 changes: 50 additions & 1 deletion thunder/core/transform_common.py
@@ -10,7 +10,7 @@
import thunder
import thunder.core.prims as prims
from thunder.core.baseutils import BoundSymbolInterface
from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify
from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify, ProxyTag
from thunder.core.pytree import tree_flatten, tree_iter, tree_map, tree_unflatten
from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags
from thunder.core.trace import from_trace, TraceProvenance, TraceCtx as Trace, tracectx
@@ -492,3 +492,52 @@ def is_context_manager_prim(bsym):
new_trace.bound_symbols = filtered_bsyms
new_trace.set_provenance(TraceProvenance("Remove context manager prims"))
return new_trace


def tag_no_grad_symbols_pass(trace: Trace) -> Trace:
    """
    This function iterates over the trace and tags the BoundSymbols
    in `no_grad` regions so that the VJP pass will treat them as constant.
    """
    is_no_grad_region = False

    # NOTE - This also copies the names from the original trace.
    new_trace = from_trace(trace)
    new_bsyms = []

    for bsym in trace.bound_symbols:
        # case - torch._C._set_grad_enabled(False)
        if bsym.sym.id == thunder.torch._set_grad_enabled_with_warning.id and not bsym.args[0]:
            is_no_grad_region = True
            continue
        # case - torch._C._set_grad_enabled(True)
        elif bsym.sym.id == thunder.torch._set_grad_enabled_with_warning.id and bsym.args[0]:
            is_no_grad_region = False
            continue

        if is_no_grad_region:
            # Mark the TensorProxy outputs of the `bsym`
            # with `ProxyTag.DETACHED_AUTOGRAD_GRAPH` so that
            # the VJP pass will treat them as constant.

            def create_detached_output(t):
                if isinstance(t, TensorProxy):
                    # NOTE - We need `tracectx` because creating/replacing a proxy's name
                    # does a look-up in the current trace.
                    with tracectx(new_trace):
                        # Remove the name so that we can re-use it.
                        # Otherwise, we would get a proxy with a new name.
                        new_trace.names.remove(t.name)
                        return t.replace(requires_grad=False, tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,))

                return t

            new_output = tree_map(create_detached_output, bsym.output)
            # Create a copy of the `bsym` with `new_output`.
            bsym = bsym.from_bsym(output=new_output)

        new_bsyms.append(bsym)

    new_trace.bound_symbols = new_bsyms
    new_trace.set_provenance(TraceProvenance("no_grad detach graph for vjp pass"))
    return new_trace
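One rough way to see the effect of this pass on the intermediate traces (assuming `thunder.last_traces` returns the list of computation traces and that printing a trace includes its provenance):

```python
import torch
import thunder

def f(x):
    with torch.no_grad():
        y = x.exp()
    return x * y

jf = thunder.jit(f)
jf(torch.randn(4, requires_grad=True))

# The trace produced by this pass carries the provenance
# "no_grad detach graph for vjp pass"; in it, the tensor produced inside the
# no_grad region has requires_grad=False and the DETACHED_AUTOGRAD_GRAPH tag.
for trc in thunder.last_traces(jf):
    print(trc)
```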
8 changes: 8 additions & 0 deletions thunder/core/transforms.py
@@ -33,6 +33,7 @@
variableify,
unvariableify,
FutureTensorProxy,
ProxyTag,
)
from thunder.core.compile_data import get_compile_data, get_compile_option
from thunder.core.langctxs import langctx, Languages
@@ -2485,10 +2486,17 @@ def is_constant_for_vjp(symbol: prims.Symbol) -> bool:
bool: True if the symbol is constant, False otherwise.
"""
are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args)
# `tag_no_grad_symbols_pass` tags the outputs of BoundSymbols in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`.
# These BoundSymbols are treated as constant for VJP.
# NOTE - `any(()) is False`, so symbols with no TensorProxy outputs are not affected by this check.
output_disconnected_from_graph = any(
ProxyTag.DETACHED_AUTOGRAD_GRAPH in o.tags for o in symbol.flat_outs if isinstance(o, TensorProxy)
)
return (
are_all_args_non_differentiable
or symbol.are_all_args_constant
or symbol.sym.id in nondifferentiable_vjp_symbols
or output_disconnected_from_graph
)
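A standalone illustration of the new check, using hypothetical `Tag` and `FakeTensorProxy` stand-ins rather than Thunder's real classes, and showing why the `any(())` note matters:

```python
from enum import Enum, auto

class Tag(Enum):  # hypothetical stand-in for ProxyTag
    DETACHED_AUTOGRAD_GRAPH = auto()

class FakeTensorProxy:  # hypothetical stand-in for TensorProxy
    def __init__(self, tags=()):
        self.tags = frozenset(tags)

def output_disconnected_from_graph(flat_outs):
    # Mirrors the added condition: any tensor output carrying the tag
    # marks the whole symbol as constant for VJP.
    return any(
        Tag.DETACHED_AUTOGRAD_GRAPH in o.tags for o in flat_outs if isinstance(o, FakeTensorProxy)
    )

assert output_disconnected_from_graph([FakeTensorProxy({Tag.DETACHED_AUTOGRAD_GRAPH})])
assert not output_disconnected_from_graph([FakeTensorProxy()])
assert not output_disconnected_from_graph([])  # any(()) is False, so the other checks decide
```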


14 changes: 11 additions & 3 deletions thunder/torch/__init__.py
@@ -47,6 +47,7 @@
ListProxy,
DictProxy,
numberproxy,
ProxyTag,
)
from thunder.core.pytree import tree_map, tree_flatten, tree_unflatten
from thunder.core.symbol import Symbol
@@ -5226,11 +5227,18 @@ def torch_device(type: DeviceLike, index: int | None = None) -> devices.Device:
register_function(torch.device, torch_device)


def _set_grad_enabled_with_warning(enabled: bool) -> None:
warnings.warn("torch.enable_grad/torch.no_grad/torch._C._set_grad_enabled have no effect under thunder.jit")
# Tag to use on Proxies created in `no_grad` regions.
# The VJP transform will treat BoundSymbols whose outputs carry this tag
# as constant.
ProxyTag.register_tag("DETACHED_AUTOGRAD_GRAPH")


register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning)
# This is just a marker Symbol. The `tag_no_grad_symbols_pass` pass uses these symbols
# to find the `no_grad` regions and mark the BoundSymbols within them as constant
# for VJP using the `DETACHED_AUTOGRAD_GRAPH` tag.
@torchsymbol(torch._C._set_grad_enabled, id="set_grad_enabled", tags=(prims.OpTags.CTX_MANAGER_ENTER_EXIT_OP,))
def _set_grad_enabled_with_warning(enabled: bool) -> None:
pass
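The marker relies on the fact that `torch.no_grad()` toggles the global grad mode through `torch._C._set_grad_enabled`, so entering and leaving the context manager shows up in the trace as a pair of `set_grad_enabled` symbols with `False` and `True` arguments. A quick eager-mode illustration of that toggling:

```python
import torch

assert torch.is_grad_enabled()
with torch.no_grad():
    # __enter__ calls torch._C._set_grad_enabled(False) ...
    assert not torch.is_grad_enabled()
# ... and __exit__ restores the previous mode (True here).
assert torch.is_grad_enabled()
```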


def _unwrap_if_dead(tensor):