add rms_norm (#1390)
t-vi authored Nov 1, 2024
1 parent 8535eed commit a24e86e
Showing 2 changed files with 80 additions and 4 deletions.
50 changes: 50 additions & 0 deletions thunder/tests/opinfos.py
@@ -7647,6 +7647,56 @@ def layer_norm_error_generator(op, device, **kwargs):
 nn_ops.append(layer_norm_opinfo)


+def rms_norm_reference_generator(op, device, dtype, requires_grad, **kwargs):
+    for sample_inputs in layer_norm_reference_generator(op, device, dtype, requires_grad, **kwargs):
+        print(sample_inputs.args)
+        if len(sample_inputs.args) > 3:  # positional bias
+            sample_inputs.args = sample_inputs.args[:3] + sample_inputs.args[4:]
+        sample_inputs.kwargs.pop("bias", None)
+        yield sample_inputs
+
+
+def rms_norm_sample_generator(op, device, dtype, requires_grad, **kwargs):
+    for sample_inputs in layer_norm_sample_generator(op, device, dtype, requires_grad, **kwargs):
+        print(sample_inputs.args)
+        if len(sample_inputs.args) > 3:  # positional bias
+            sample_inputs.args = sample_inputs.args[:3] + sample_inputs.args[4:]
+        sample_inputs.kwargs.pop("bias", None)
+        yield sample_inputs
+
+
+def rms_norm_error_generator(op, device, **kwargs):
+    for sample_inputs, exc_type, msg in layer_norm_error_generator(op, device, **kwargs):
+        print(sample_inputs.args)
+        if len(sample_inputs.args) > 3:  # positional bias
+            sample_inputs.args = sample_inputs.args[:3] + sample_inputs.args[4:]
+        sample_inputs.kwargs.pop("bias", None)
+        if "bias" not in msg:
+            yield sample_inputs, exc_type, msg
+
+
+if LooseVersion(torch.__version__) >= "2.4":
+    rms_norm_opinfo = OpInfo(
+        ltorch.rms_norm,
+        sample_input_generator=rms_norm_sample_generator,
+        error_input_generator=rms_norm_error_generator,
+        reference_input_generator=rms_norm_reference_generator,
+        torch_reference=torch.nn.functional.rms_norm,
+        # Complex var is not supported yet
+        dtypes=(datatypes.floating,),
+        test_directives=(
+            # PyTorch does not support float16 on CPU
+            DecorateInfo(
+                pytest.mark.xfail,
+                "test_core_vs_torch_consistency",
+                dtypes=(datatypes.float16,),
+                devicetypes=(devices.DeviceType.CPU,),
+            ),
+        ),
+    )
+    nn_ops.append(rms_norm_opinfo)
+
+
 def batch_norm_reference_generator(op, device, dtype, requires_grad, **kwargs):
     yield from layer_norm_sample_generator(op, device, dtype, requires_grad, **kwargs)

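For context (not part of the commit): on PyTorch >= 2.4, torch.nn.functional.rms_norm is used as the torch_reference above. Below is a minimal sketch of an equivalent pure-PyTorch reference, assuming eps defaults to torch.finfo(dtype).eps when left unset (the same default the Thunder implementation further down uses); shapes are illustrative only.

import torch

def rms_norm_reference(a, normalized_shape, weight=None, eps=None):
    # Normalize by the root mean square over the trailing normalized_shape dims.
    if eps is None:
        eps = torch.finfo(a.dtype).eps  # assumed default, mirroring the Thunder implementation
    dims = tuple(range(a.ndim - len(normalized_shape), a.ndim))
    a_normed = a * torch.rsqrt(a.pow(2).mean(dim=dims, keepdim=True) + eps)
    return a_normed if weight is None else a_normed * weight

a = torch.randn(2, 3, 8)
if hasattr(torch.nn.functional, "rms_norm"):  # PyTorch >= 2.4
    torch.testing.assert_close(
        torch.nn.functional.rms_norm(a, (8,)), rms_norm_reference(a, (8,))
    )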
34 changes: 30 additions & 4 deletions thunder/torch/__init__.py
@@ -3584,10 +3584,7 @@ def normalize(
     return out


-# TODO: likely want to refactor these normalizations
-def _native_layer_norm(
-    a: TensorProxy, /, normalized_shape, weight, bias, eps: Number
-) -> tuple[TensorLike, TensorLike, TensorLike]:
+def _check_normalized_shape_and_get_reduction_dims(a, normalized_shape, weight=None, bias=None):
     # Validates inputs
     normalized_ndim = len(normalized_shape)
     utils.check(normalized_ndim >= 1, lambda: f"Expected normalized_shape={normalized_shape} to have length >= 1!")
@@ -3613,6 +3610,14 @@ def _native_layer_norm(

     axis = a.ndim - normalized_ndim
     reduction_dims = list(range(axis, a.ndim))
+    return reduction_dims
+
+
+# TODO: likely want to refactor these normalizations
+def _native_layer_norm(
+    a: TensorProxy, /, normalized_shape, weight, bias, eps: Number
+) -> tuple[TensorLike, TensorLike, TensorLike]:
+    reduction_dims = _check_normalized_shape_and_get_reduction_dims(a, normalized_shape, weight, bias)
     out, mean, rstd = _normalize(a, reduction_dims, eps)

     # Handles weight and bias
@@ -3653,6 +3658,27 @@ def layer_norm(
     return _native_layer_norm(a, normalized_shape, weight, bias, eps)[0]


+def rms_norm(
+    a: TensorLike,
+    /,
+    normalized_shape: Sequence[int],
+    weight: None | TensorLike = None,
+    eps: None | float = None,
+):
+    if eps is None:
+        eps = torch.finfo(to_torch_dtype(a.dtype)).eps
+    reduction_dims = _check_normalized_shape_and_get_reduction_dims(a, normalized_shape, weight)
+    norm_a = mean(a * a, dim=reduction_dims, keepdim=True)
+    a_normed = a * rsqrt(norm_a + eps)
+    if weight is not None:
+        a_normed = a_normed * weight
+    return a_normed
+
+
+if hasattr(torch.nn.functional, "rms_norm"):
+    rms_norm = torchsymbol(torch.nn.functional.rms_norm)(rms_norm)
+
+
 def _native_batch_norm(
     a: TensorLike,
     /,
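A rough end-to-end usage sketch (not part of the commit), assuming PyTorch >= 2.4 so that torch.nn.functional.rms_norm exists, and using thunder.jit to trace a function that should exercise the newly registered symbol; the shapes and eps value are illustrative.

import torch
import thunder

def fn(a, weight):
    # When traced by Thunder, this call should map to the rms_norm torchsymbol registered above.
    return torch.nn.functional.rms_norm(a, (8,), weight=weight, eps=1e-6)

a = torch.randn(2, 3, 8)
weight = torch.randn(8)
jfn = thunder.jit(fn)
torch.testing.assert_close(jfn(a, weight), fn(a, weight))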
