Commit 18e7ece

TE - support v1.11 (current main) (#1052)
kshitij12345 authored Aug 27, 2024
1 parent d1f563b commit 18e7ece
Showing 1 changed file with 30 additions and 14 deletions.
44 changes: 30 additions & 14 deletions thunder/executors/transformer_engineex.py
@@ -38,8 +38,8 @@
 # Ex. addition of a positional argument for cpu_offloading (not as the last argument)
 # between version 1.2 and 1.3.
 # Hence, we have these guards based on version.
-TE_VERSION_1_6_PLUS: bool = False
 TE_VERSION_1_8_PLUS: bool = False
+TE_VERSION_1_11_PLUS: bool = False

 te: None | Any = None
 if TE_AVAILABLE:
@@ -55,8 +55,8 @@
         warnings.warn(f"transformer_engine failed to import with exception {ex}")
         TE_AVAILABLE = False

-    TE_VERSION_1_6_PLUS = LooseVersion(version("transformer_engine")) > LooseVersion("1.6")
     TE_VERSION_1_8_PLUS = LooseVersion(version("transformer_engine")) > LooseVersion("1.8")
+    TE_VERSION_1_11_PLUS = LooseVersion(version("transformer_engine")) > LooseVersion("1.11")
     if not TE_VERSION_1_8_PLUS:
         warnings.warn(
             f"Installed version of transformer_engine {version('transformer_engine')} is not supported, please upgrade. `transformer_engine_ex` will not be used."
@@ -348,19 +348,33 @@ def _te_functional_linear_meta(
         global LINEAR_CALLS_COUNTER
         ctx_dict = AnyProxy(object(), name=f"ctx_te_{LINEAR_CALLS_COUNTER}")

-        # https://github.com/NVIDIA/TransformerEngine/blob/37280ecd5e9c6087d18fbe2e668f2ec7761ada3d/transformer_engine/pytorch/module/linear.py#L323-L330
         # It's not critical to model the exact shape and dtype of
         # saved_tensors since they are not used in Thunder's meta functions.
-        saved_tensors = (
-            TensorProxy(like=a, shape=a.shape),  # saved_inputmat
-            TensorProxy(like=a, shape=(a.shape[:-2] + (a.shape[-1], a.shape[-2]))),  # saved_inputmat_t
-            TensorProxy(like=w, shape=w.shape),  # weight
-            TensorProxy(like=w, shape=(w.shape[1], w.shape[0]), dtype=float8_e4m3fn),  # weight_fp8
-            # fuse_wgrad_accumulation is False
-            # https://github.com/Lightning-AI/lightning-thunder/blob/40da5bd5fabc30e99883d74b70c6a7d7fd61a828/thunder/executors/transformer_engineex.py#L224
-            None,  # weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
-            TensorProxy(like=a, shape=(1,)),  # scaling_fwd
-        )
+        if TE_VERSION_1_11_PLUS:
+            # NOTE - Position of scaling tensor denoted by `scaling_fwd` is different (compared to `else`).
+            # https://github.com/NVIDIA/TransformerEngine/blob/7fc50f489b8184fbd93efd4e48140ad0264e362b/transformer_engine/pytorch/module/linear.py#L330C13-L337C14
+            saved_tensors = (
+                TensorProxy(like=a, shape=a.shape),  # saved_inputmat
+                TensorProxy(like=a, shape=(a.shape[:-2] + (a.shape[-1], a.shape[-2]))),  # saved_inputmat_t
+                TensorProxy(like=a, shape=(1,)),  # scaling_fwd
+                TensorProxy(like=w, shape=w.shape),  # weight
+                TensorProxy(like=w, shape=(w.shape[1], w.shape[0]), dtype=float8_e4m3fn),  # weight_fp8
+                # fuse_wgrad_accumulation is False
+                # https://github.com/Lightning-AI/lightning-thunder/blob/40da5bd5fabc30e99883d74b70c6a7d7fd61a828/thunder/executors/transformer_engineex.py#L224
+                None,  # weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
+            )
+        else:
+            # https://github.com/NVIDIA/TransformerEngine/blob/37280ecd5e9c6087d18fbe2e668f2ec7761ada3d/transformer_engine/pytorch/module/linear.py#L323-L330
+            saved_tensors = (
+                TensorProxy(like=a, shape=a.shape),  # saved_inputmat
+                TensorProxy(like=a, shape=(a.shape[:-2] + (a.shape[-1], a.shape[-2]))),  # saved_inputmat_t
+                TensorProxy(like=w, shape=w.shape),  # weight
+                TensorProxy(like=w, shape=(w.shape[1], w.shape[0]), dtype=float8_e4m3fn),  # weight_fp8
+                # fuse_wgrad_accumulation is False
+                # https://github.com/Lightning-AI/lightning-thunder/blob/40da5bd5fabc30e99883d74b70c6a7d7fd61a828/thunder/executors/transformer_engineex.py#L224
+                None,  # weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
+                TensorProxy(like=a, shape=(1,)),  # scaling_fwd
+            )

         return TensorProxy(like=a, shape=output_shape), saved_tensors, ctx_dict
     return TensorProxy(like=a, shape=output_shape), None, None
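Concretely, from TE v1.11 the `scaling_fwd` tensor moves from the last slot to index 2, which shifts `weight` from index 2 to index 3. A hypothetical helper illustrating the two layouts (`unpack_weight` is illustrative only and does not exist in the executor):

def unpack_weight(saved_tensors: tuple, te_1_11_plus: bool):
    if te_1_11_plus:
        # (inputmat, inputmat_t, scaling_fwd, weight, weight_fp8, main_grad)
        return saved_tensors[3]
    # (inputmat, inputmat_t, weight, weight_fp8, main_grad, scaling_fwd)
    return saved_tensors[2]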
@@ -381,9 +395,11 @@ def _te_functional_linear_backward_impl(
 ) -> [torch.Tensor, torch.Tensor, None | torch.Tensor]:
     # See [NOTE] Enable grad within context
     # _Linear.backward depends on requires grad of `weight/ctx.saved_tensors[2]`.
+    # NOTE - weight is ctx.saved_tensors[3] from TE v1.11 onwards
     # Hence we enable requires_grad for computation.
     # https://github.com/NVIDIA/TransformerEngine/blob/b957aa475bcbcf22405381d18bd7fefe4fb6b171/transformer_engine/pytorch/module/linear.py#L434
-    with set_saved_tensors(ctx, saved_tensors), enable_grad(saved_tensors[2]):
+    weight_t = saved_tensors[3] if TE_VERSION_1_11_PLUS else saved_tensors[2]
+    with set_saved_tensors(ctx, saved_tensors), enable_grad(weight_t):
         grads = _Linear.backward(ctx, g)

     # Due to different in `_Linear.forward` API, position of
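`set_saved_tensors` is a context manager defined elsewhere in this file. A plausible sketch of the pattern it implements, assuming it simply attaches the tuple to the ctx for the duration of the backward call (the actual implementation may differ):

from contextlib import contextmanager


@contextmanager
def set_saved_tensors(ctx, saved_tensors):
    # Expose the tensors to _Linear.backward via ctx.saved_tensors.
    ctx.saved_tensors = saved_tensors
    try:
        yield
    finally:
        # Drop the reference so the tensors are not kept alive past backward.
        del ctx.saved_tensors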
