diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py index ced5d8fdb1..ba53c3e940 100644 --- a/thunder/executors/transformer_engineex.py +++ b/thunder/executors/transformer_engineex.py @@ -20,6 +20,7 @@ from thunder.core.proxies import TensorProxy, CollectionProxy from thunder.core.symbol import Symbol from thunder.extend import OperatorExecutor, register_executor +from thunder.core.langctxs import langctx, Languages __all__ = [ "transformer_engine_ex", @@ -369,6 +370,10 @@ def bind_postprocess(bsym: BoundSymbol) -> None: # # Registers transformer_engine_ex as an executor for torch.nn.functional.linear # + + +# NOTE: We need langctx so that we can resolve `view` on TensorProxy. +@langctx(Languages.TORCH) def _linear_checker( a: TensorProxy, w: TensorProxy, @@ -398,15 +403,6 @@ def linear_forwad_rule(a, w, bias): return primal, saved_for_backward -def linear_forward_rule_checker(a: TensorProxy, w: TensorProxy, bias: None | TensorProxy) -> bool: - from thunder.core.compile_data import get_compile_data - - cd = get_compile_data() - if transformer_engine_ex in cd.executors_list: - return _linear_checker(a, w, bias) - return False - - def linear_backward_rule(a_shape, w_shape, b_shape, ctx_idx, grad): return te_functional_linear_backward(grad, a_shape, w_shape, b_shape, ctx_idx)