support no_grad in thunder.jit
kshitij12345 committed Nov 11, 2024
1 parent f7b2e15 commit d646dd1
Showing 4 changed files with 73 additions and 4 deletions.
4 changes: 4 additions & 0 deletions thunder/__init__.py
@@ -35,6 +35,7 @@
    wrap_return_value_together_with_argments,
    unwrap_return_value,
    remove_context_manager_prims_from_trace,
    tag_no_grad_symbols_pass,
)
from thunder.core.functionalization import (
    check_inplace_to_views,
@@ -538,6 +539,9 @@ def get_computation_and_inputs(*args, **kwargs):
        computation_trc = wrap_return_value_together_with_argments(computation_trc)
        computation_traces.append(computation_trc)

        computation_trc = tag_no_grad_symbols_pass(computation_trc)
        computation_traces.append(computation_trc)

        computation_trc = remove_context_manager_prims_from_trace(computation_trc)
        computation_traces.append(computation_trc)

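Taken together, the pipeline change above means a `torch.no_grad` region inside a jitted function is now honored (previously it only produced a warning that it has no effect). A minimal usage sketch of the intended behavior; illustrative only, assuming a working Thunder install:

import torch
import thunder

def f(x):
    with torch.no_grad():
        scale = x.sum()  # produced inside no_grad: treated as constant for VJP
    return x * scale

x = torch.randn(3, requires_grad=True)
out = thunder.jit(f)(x)
out.sum().backward()
# The gradient flows only through the explicit `x` factor, matching eager PyTorch.
print(x.grad)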
51 changes: 50 additions & 1 deletion thunder/core/transform_common.py
@@ -10,7 +10,7 @@
import thunder
import thunder.core.prims as prims
from thunder.core.baseutils import BoundSymbolInterface
from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify
from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify, ProxyTag
from thunder.core.pytree import tree_flatten, tree_iter, tree_map, tree_unflatten
from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags
from thunder.core.trace import from_trace, TraceProvenance, TraceCtx as Trace, tracectx
@@ -492,3 +492,52 @@ def is_context_manager_prim(bsym):
    new_trace.bound_symbols = filtered_bsyms
    new_trace.set_provenance(TraceProvenance("Remove context manager prims"))
    return new_trace


def tag_no_grad_symbols_pass(trace: Trace) -> Trace:
    """
    This pass iterates over the trace and tags BoundSymbols
    in `no_grad` regions so that the VJP pass treats them as constant.
    """
    is_no_grad_region = False

    # NOTE - This will also copy the name from the original trace.
    new_trace = from_trace(trace)
    new_bsyms = []

    for bsym in trace.bound_symbols:
        # case - torch._C._set_grad_enabled(False)
        if bsym.sym.id == thunder.torch._set_grad_enabled_with_warning.id and not bsym.args[0]:
            is_no_grad_region = True
            continue
        # case - torch._C._set_grad_enabled(True)
        elif bsym.sym.id == thunder.torch._set_grad_enabled_with_warning.id and bsym.args[0]:
            is_no_grad_region = False
            continue

        if is_no_grad_region:
            # Tag each TensorProxy output of the `bsym`
            # with `ProxyTag.DETACHED_AUTOGRAD_GRAPH` so that
            # the VJP pass will treat it as constant.

            def create_detached_output(t):
                if isinstance(t, TensorProxy):
                    # NOTE - We need `tracectx` because creating/replacing a proxy name
                    # performs a look-up in the current trace.
                    with tracectx(new_trace):
                        # Remove the name so that we can re-use it.
                        # Otherwise, we would get a proxy with a new name.
                        new_trace.names.remove(t.name)
                        return t.replace(requires_grad=False, tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,))

                return t

            new_output = tree_map(create_detached_output, bsym.output)
            # Create a copy of the `bsym` with `new_output`.
            bsym = bsym.from_bsym(output=new_output)

        new_bsyms.append(bsym)

    new_trace.bound_symbols = new_bsyms
    new_trace.set_provenance(TraceProvenance("no_grad detach graph for vjp pass"))
    return new_trace
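For readers skimming the pass above: it is a single forward scan with one boolean flag that the marker symbols flip, tagging everything produced while the flag is set and dropping the markers themselves. A self-contained toy version of the same pattern, using stand-in types rather than Thunder's:

from dataclasses import dataclass, field

@dataclass
class Op:
    name: str
    args: tuple = ()
    tags: set = field(default_factory=set)

ops = [
    Op("mul"),
    Op("set_grad_enabled", args=(False,)),  # enters the no_grad region
    Op("add"),
    Op("set_grad_enabled", args=(True,)),   # exits the no_grad region
    Op("sub"),
]

in_no_grad = False
kept = []
for op in ops:
    if op.name == "set_grad_enabled":
        in_no_grad = not op.args[0]
        continue  # the marker op itself is dropped
    if in_no_grad:
        op.tags.add("DETACHED_AUTOGRAD_GRAPH")
    kept.append(op)

assert [op.name for op in kept] == ["mul", "add", "sub"]
assert kept[1].tags == {"DETACHED_AUTOGRAD_GRAPH"} and not kept[0].tags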
8 changes: 8 additions & 0 deletions thunder/core/transforms.py
@@ -33,6 +33,7 @@
    variableify,
    unvariableify,
    FutureTensorProxy,
    ProxyTag,
)
from thunder.core.compile_data import get_compile_data, get_compile_option
from thunder.core.langctxs import langctx, Languages
@@ -2485,10 +2486,17 @@ def is_constant_for_vjp(symbol: prims.Symbol) -> bool:
        bool: True if the symbol is constant, False otherwise.
    """
    are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args)
    # `tag_no_grad_symbols_pass` tags the outputs of BoundSymbols in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`.
    # These outputs are treated as constant for VJP.
    # NOTE - `any(())` is `False`, so symbols without tagged outputs are unaffected.
    output_disconnected_from_graph = any(
        ProxyTag.DETACHED_AUTOGRAD_GRAPH in o.tags for o in symbol.flat_outs if isinstance(o, TensorProxy)
    )
    return (
        are_all_args_non_differentiable
        or symbol.are_all_args_constant
        or symbol.sym.id in nondifferentiable_vjp_symbols
        or output_disconnected_from_graph
    )


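The `NOTE` above matters for symbols with no TensorProxy outputs: the generator is then empty, `any(())` evaluates to `False`, and such symbols keep their previous classification. A toy check with a stand-in type (not the real TensorProxy):

class T:
    def __init__(self, tags=()):
        self.tags = set(tags)

def disconnected(outs):
    return any("DETACHED_AUTOGRAD_GRAPH" in o.tags for o in outs if isinstance(o, T))

assert disconnected([T({"DETACHED_AUTOGRAD_GRAPH"})])
assert not disconnected([T()])
assert not disconnected([])  # any(()) is False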
14 changes: 11 additions & 3 deletions thunder/torch/__init__.py
@@ -47,6 +47,7 @@
    ListProxy,
    DictProxy,
    numberproxy,
    ProxyTag,
)
from thunder.core.pytree import tree_map, tree_flatten, tree_unflatten
from thunder.core.symbol import Symbol
@@ -5226,11 +5227,18 @@ def torch_device(type: DeviceLike, index: int | None = None) -> devices.Device:
register_function(torch.device, torch_device)


def _set_grad_enabled_with_warning(enabled: bool) -> None:
    warnings.warn("torch.enable_grad/torch.no_grad/torch._C._set_grad_enabled have no effect under thunder.jit")
# Tag to use on Proxies created in `no_grad` regions.
# The VJP transform will treat BoundSymbols whose outputs carry this tag
# as constant.
ProxyTag.register_tag("DETACHED_AUTOGRAD_GRAPH")


register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning)
# This is just a marker Symbol. `tag_no_grad_symbols_pass` uses these symbols
# to find `no_grad` regions and tag the BoundSymbols within them as constant
# for VJP using the `DETACHED_AUTOGRAD_GRAPH` tag.
@torchsymbol(torch._C._set_grad_enabled, id="set_grad_enabled", tags=(prims.OpTags.CTX_MANAGER_ENTER_EXIT_OP,))
def _set_grad_enabled_with_warning(enabled: bool) -> None:
    pass


def _unwrap_if_dead(tensor):
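One way to see the marker symbol and the new pass at work is to inspect the recorded traces after running a jitted function. A hedged sketch; it assumes a working Thunder install, and the exact trace contents will vary:

import torch
import thunder

def f(x):
    with torch.no_grad():
        y = x + 1
    return y * x

jf = thunder.jit(f)
jf(torch.randn(2))

# thunder.last_traces returns the sequence of computation traces; the entry
# appended by tag_no_grad_symbols_pass carries the provenance string
# "no_grad detach graph for vjp pass" set in transform_common.py above.
for trc in thunder.last_traces(jf):
    print(trc)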
