diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index f02f3312ba..48a0d3d67e 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -7692,6 +7692,12 @@ def rms_norm_error_generator(op, device, **kwargs): dtypes=(datatypes.float16,), devicetypes=(devices.DeviceType.CPU,), ), + # See issue - https://github.com/Lightning-AI/lightning-thunder/issues/1395 + DecorateInfo( + custom_comparator(partial(assert_close, atol=2e-3, rtol=2e-3)), + dtypes=(datatypes.float16), + devicetypes=(devices.DeviceType.CUDA), + ), ), ) nn_ops.append(rms_norm_opinfo)