
add rms_norm #1390


Merged · 1 commit · Nov 1, 2024
50 changes: 50 additions & 0 deletions thunder/tests/opinfos.py
@@ -7647,6 +7647,56 @@ def layer_norm_error_generator(op, device, **kwargs):
nn_ops.append(layer_norm_opinfo)


def rms_norm_reference_generator(op, device, dtype, requires_grad, **kwargs):
for sample_inputs in layer_norm_reference_generator(op, device, dtype, requires_grad, **kwargs):
print(sample_inputs.args)
if len(sample_inputs.args) > 3: # positional bias
sample_inputs.args = sample_inputs.args[:3] + sample_inputs.args[4:]
sample_inputs.kwargs.pop("bias", None)
yield sample_inputs


def rms_norm_sample_generator(op, device, dtype, requires_grad, **kwargs):
for sample_inputs in layer_norm_sample_generator(op, device, dtype, requires_grad, **kwargs):
print(sample_inputs.args)
if len(sample_inputs.args) > 3: # positional bias
sample_inputs.args = sample_inputs.args[:3] + sample_inputs.args[4:]
sample_inputs.kwargs.pop("bias", None)
yield sample_inputs


def rms_norm_error_generator(op, device, **kwargs):
for sample_inputs, exc_type, msg in layer_norm_error_generator(op, device, **kwargs):
print(sample_inputs.args)
if len(sample_inputs.args) > 3: # positional bias
sample_inputs.args = sample_inputs.args[:3] + sample_inputs.args[4:]
sample_inputs.kwargs.pop("bias", None)
if "bias" not in msg:
yield sample_inputs, exc_type, msg


if LooseVersion(torch.__version__) >= "2.4":
rms_norm_opinfo = OpInfo(
ltorch.rms_norm,
sample_input_generator=rms_norm_sample_generator,
error_input_generator=rms_norm_error_generator,
reference_input_generator=rms_norm_reference_generator,
torch_reference=torch.nn.functional.rms_norm,
# Complex var is not supported yet
dtypes=(datatypes.floating,),
test_directives=(
# PyTorch does not support float16 on CPU
DecorateInfo(
pytest.mark.xfail,
"test_core_vs_torch_consistency",
dtypes=(datatypes.float16,),
devicetypes=(devices.DeviceType.CPU,),
),
),
)
nn_ops.append(rms_norm_opinfo)


def batch_norm_reference_generator(op, device, dtype, requires_grad, **kwargs):
yield from layer_norm_sample_generator(op, device, dtype, requires_grad, **kwargs)

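For context, the three rms_norm generators above all make the same adjustment to the layer_norm samples: they drop the bias argument, because torch.nn.functional.rms_norm takes no bias. A minimal sketch of that slicing, assuming the layer_norm samples order their positional arguments as (a, normalized_shape, weight, bias, eps) — the plain tuple below is only a stand-in for a SampleInput, not code from the PR:

    # Stand-in for SampleInput.args from the layer_norm generators (illustrative values).
    args = ("a", "normalized_shape", "weight", "bias", 1e-5)

    # len(args) > 3 means bias was passed positionally at index 3;
    # args[:3] + args[4:] keeps everything else and removes only bias.
    if len(args) > 3:
        args = args[:3] + args[4:]

    print(args)  # ('a', 'normalized_shape', 'weight', 1e-05)
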
34 changes: 30 additions & 4 deletions thunder/torch/__init__.py
@@ -3584,10 +3584,7 @@ def normalize(
return out


# TODO: likely want to refactor these normalizations
def _native_layer_norm(
a: TensorProxy, /, normalized_shape, weight, bias, eps: Number
) -> tuple[TensorLike, TensorLike, TensorLike]:
def _check_normalized_shape_and_get_reduction_dims(a, normalized_shape, weight=None, bias=None):
# Validates inputs
normalized_ndim = len(normalized_shape)
utils.check(normalized_ndim >= 1, lambda: f"Expected normalized_shape={normalized_shape} to have length >= 1!")
@@ -3613,6 +3610,14 @@ def _native_layer_norm(

axis = a.ndim - normalized_ndim
reduction_dims = list(range(axis, a.ndim))
return reduction_dims


# TODO: likely want to refactor these normalizations
def _native_layer_norm(
a: TensorProxy, /, normalized_shape, weight, bias, eps: Number
) -> tuple[TensorLike, TensorLike, TensorLike]:
reduction_dims = _check_normalized_shape_and_get_reduction_dims(a, normalized_shape, weight, bias)
out, mean, rstd = _normalize(a, reduction_dims, eps)

# Handles weight and bias
@@ -3653,6 +3658,27 @@ def layer_norm(
return _native_layer_norm(a, normalized_shape, weight, bias, eps)[0]


def rms_norm(
a: TensorLike,
/,
normalized_shape: Sequence[int],
weight: None | TensorLike = None,
eps: None | float = None,
):
if eps is None:
eps = torch.finfo(to_torch_dtype(a.dtype)).eps
reduction_dims = _check_normalized_shape_and_get_reduction_dims(a, normalized_shape, weight)
norm_a = mean(a * a, dim=reduction_dims, keepdim=True)
a_normed = a * rsqrt(norm_a + eps)
if weight is not None:
a_normed = a_normed * weight
return a_normed


if hasattr(torch.nn.functional, "rms_norm"):
rms_norm = torchsymbol(torch.nn.functional.rms_norm)(rms_norm)


def _native_batch_norm(
a: TensorLike,
/,
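To make the new symbol concrete: rms_norm normalizes over the trailing normalized_shape dimensions by the root mean square of the input, y = x * rsqrt(mean(x^2) + eps) * weight, with eps defaulting to the dtype's machine epsilon. Below is a small eager-mode sketch of the same computation, checked against torch.nn.functional.rms_norm (requires PyTorch >= 2.4); the reference helper is ours, not part of the PR.

    import torch

    def rms_norm_reference(a, normalized_shape, weight=None, eps=None):
        # Mirrors the thunder.torch.rms_norm above: eps defaults to the dtype's eps.
        if eps is None:
            eps = torch.finfo(a.dtype).eps
        # normalized_shape names the trailing dims; reduce over exactly those dims.
        reduction_dims = list(range(a.ndim - len(normalized_shape), a.ndim))
        norm_a = (a * a).mean(dim=reduction_dims, keepdim=True)
        a_normed = a * torch.rsqrt(norm_a + eps)
        if weight is not None:
            a_normed = a_normed * weight
        return a_normed

    x = torch.randn(2, 3, 4)
    w = torch.randn(4)
    torch.testing.assert_close(
        rms_norm_reference(x, (4,), weight=w),
        torch.nn.functional.rms_norm(x, (4,), weight=w),
    )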