diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index bdd322966d..f02f3312ba 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -7647,6 +7647,56 @@ def layer_norm_error_generator(op, device, **kwargs): nn_ops.append(layer_norm_opinfo) +def rms_norm_reference_generator(op, device, dtype, requires_grad, **kwargs): + for sample_inputs in layer_norm_reference_generator(op, device, dtype, requires_grad, **kwargs): + print(sample_inputs.args) + if len(sample_inputs.args) > 3: # positional bias + sample_inputs.args = sample_inputs.args[:3] + sample_inputs.args[4:] + sample_inputs.kwargs.pop("bias", None) + yield sample_inputs + + +def rms_norm_sample_generator(op, device, dtype, requires_grad, **kwargs): + for sample_inputs in layer_norm_sample_generator(op, device, dtype, requires_grad, **kwargs): + print(sample_inputs.args) + if len(sample_inputs.args) > 3: # positional bias + sample_inputs.args = sample_inputs.args[:3] + sample_inputs.args[4:] + sample_inputs.kwargs.pop("bias", None) + yield sample_inputs + + +def rms_norm_error_generator(op, device, **kwargs): + for sample_inputs, exc_type, msg in layer_norm_error_generator(op, device, **kwargs): + print(sample_inputs.args) + if len(sample_inputs.args) > 3: # positional bias + sample_inputs.args = sample_inputs.args[:3] + sample_inputs.args[4:] + sample_inputs.kwargs.pop("bias", None) + if "bias" not in msg: + yield sample_inputs, exc_type, msg + + +if LooseVersion(torch.__version__) >= "2.4": + rms_norm_opinfo = OpInfo( + ltorch.rms_norm, + sample_input_generator=rms_norm_sample_generator, + error_input_generator=rms_norm_error_generator, + reference_input_generator=rms_norm_reference_generator, + torch_reference=torch.nn.functional.rms_norm, + # Complex var is not supported yet + dtypes=(datatypes.floating,), + test_directives=( + # PyTorch does not support float16 on CPU + DecorateInfo( + pytest.mark.xfail, + "test_core_vs_torch_consistency", + dtypes=(datatypes.float16,), + devicetypes=(devices.DeviceType.CPU,), + ), + ), + ) + nn_ops.append(rms_norm_opinfo) + + def batch_norm_reference_generator(op, device, dtype, requires_grad, **kwargs): yield from layer_norm_sample_generator(op, device, dtype, requires_grad, **kwargs) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 7e27f1fbb3..4056dee2c9 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -3584,10 +3584,7 @@ def normalize( return out -# TODO: likely want to refactor these normalizations -def _native_layer_norm( - a: TensorProxy, /, normalized_shape, weight, bias, eps: Number -) -> tuple[TensorLike, TensorLike, TensorLike]: +def _check_normalized_shape_and_get_reduction_dims(a, normalized_shape, weight=None, bias=None): # Validates inputs normalized_ndim = len(normalized_shape) utils.check(normalized_ndim >= 1, lambda: f"Expected normalized_shape={normalized_shape} to have length >= 1!") @@ -3613,6 +3610,14 @@ def _native_layer_norm( axis = a.ndim - normalized_ndim reduction_dims = list(range(axis, a.ndim)) + return reduction_dims + + +# TODO: likely want to refactor these normalizations +def _native_layer_norm( + a: TensorProxy, /, normalized_shape, weight, bias, eps: Number +) -> tuple[TensorLike, TensorLike, TensorLike]: + reduction_dims = _check_normalized_shape_and_get_reduction_dims(a, normalized_shape, weight, bias) out, mean, rstd = _normalize(a, reduction_dims, eps) # Handles weight and bias @@ -3653,6 +3658,27 @@ def layer_norm( return _native_layer_norm(a, normalized_shape, weight, bias, eps)[0] +def rms_norm( + a: TensorLike, + /, + normalized_shape: Sequence[int], + weight: None | TensorLike = None, + eps: None | float = None, +): + if eps is None: + eps = torch.finfo(to_torch_dtype(a.dtype)).eps + reduction_dims = _check_normalized_shape_and_get_reduction_dims(a, normalized_shape, weight) + norm_a = mean(a * a, dim=reduction_dims, keepdim=True) + a_normed = a * rsqrt(norm_a + eps) + if weight is not None: + a_normed = a_normed * weight + return a_normed + + +if hasattr(torch.nn.functional, "rms_norm"): + rms_norm = torchsymbol(torch.nn.functional.rms_norm)(rms_norm) + + def _native_batch_norm( a: TensorLike, /,