transformer_engine: wrap checker_fn in langctx and cleanup (PR2473) (#24)

Co-authored-by: Thomas Viehmann <tv.code@beamnet.de>
kshitij12345 and t-vi authored Mar 25, 2024
1 parent 2e0bb61 commit b873afa
Showing 1 changed file with 5 additions and 9 deletions.
thunder/executors/transformer_engineex.py (14 changes: 5 additions & 9 deletions)

@@ -20,6 +20,7 @@
 from thunder.core.proxies import TensorProxy, CollectionProxy
 from thunder.core.symbol import Symbol
 from thunder.extend import OperatorExecutor, register_executor
+from thunder.core.langctxs import langctx, Languages

 __all__ = [
     "transformer_engine_ex",

@@ -369,6 +370,10 @@ def bind_postprocess(bsym: BoundSymbol) -> None:
 #
 # Registers transformer_engine_ex as an executor for torch.nn.functional.linear
 #
+
+
+# NOTE: We need langctx so that we can resolve `view` on TensorProxy.
+@langctx(Languages.TORCH)
 def _linear_checker(
     a: TensorProxy,
     w: TensorProxy,

@@ -398,15 +403,6 @@ def linear_forwad_rule(a, w, bias):
     return primal, saved_for_backward


-def linear_forward_rule_checker(a: TensorProxy, w: TensorProxy, bias: None | TensorProxy) -> bool:
-    from thunder.core.compile_data import get_compile_data
-
-    cd = get_compile_data()
-    if transformer_engine_ex in cd.executors_list:
-        return _linear_checker(a, w, bias)
-    return False
-
-
 def linear_backward_rule(a_shape, w_shape, b_shape, ctx_idx, grad):
     return te_functional_linear_backward(grad, a_shape, w_shape, b_shape, ctx_idx)

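For context, a minimal sketch of the pattern this commit relies on, assuming a working lightning-thunder install: decorating a checker with @langctx(Languages.TORCH) makes Thunder's torch language context active inside the function, so TensorProxy methods such as `view` can be resolved when the executor invokes the checker. The checker name and its acceptance rule below (_toy_linear_checker, the 2D-weight test) are hypothetical illustrations, not the repository's actual _linear_checker.

    from thunder.core.langctxs import langctx, Languages
    from thunder.core.proxies import TensorProxy


    # Hypothetical checker illustrating the langctx pattern from this commit.
    # Without the decorator, torch-style method lookups on a TensorProxy
    # (e.g. `a.view(...)`) inside the checker body may fail to resolve,
    # because no language context is active when the executor calls it.
    @langctx(Languages.TORCH)
    def _toy_linear_checker(a: TensorProxy, w: TensorProxy, bias: None | TensorProxy) -> bool:
        # Hypothetical acceptance rule: only claim linears with 2D weights.
        return len(w.shape) == 2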