Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformer_engine: wrap checker_fn in langctx and cleanup (PR2473) #24

Merged
merged 5 commits into from
Mar 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions thunder/executors/transformer_engineex.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from thunder.core.proxies import TensorProxy, CollectionProxy
from thunder.core.symbol import Symbol
from thunder.extend import OperatorExecutor, register_executor
from thunder.core.langctxs import langctx, Languages

__all__ = [
"transformer_engine_ex",
Expand Down Expand Up @@ -369,6 +370,10 @@ def bind_postprocess(bsym: BoundSymbol) -> None:
#
# Registers transformer_engine_ex as an executor for torch.nn.functional.linear
#


# NOTE: We need langctx so that we can resolve `view` on TensorProxy.
@langctx(Languages.TORCH)
def _linear_checker(
a: TensorProxy,
w: TensorProxy,
Expand Down Expand Up @@ -398,15 +403,6 @@ def linear_forwad_rule(a, w, bias):
return primal, saved_for_backward


def linear_forward_rule_checker(a: TensorProxy, w: TensorProxy, bias: None | TensorProxy) -> bool:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
from thunder.core.compile_data import get_compile_data

cd = get_compile_data()
if transformer_engine_ex in cd.executors_list:
return _linear_checker(a, w, bias)
return False


def linear_backward_rule(a_shape, w_shape, b_shape, ctx_idx, grad):
return te_functional_linear_backward(grad, a_shape, w_shape, b_shape, ctx_idx)

Expand Down
Loading