diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index d41a9ae433..81ddffe369 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -1649,6 +1649,7 @@ def gen(op, device, dtype, requires_grad): dtypes=(datatypes.floating,), sample_input_generator=get_elementwise_unary_with_alpha_generator(), torch_reference=_elementwise_unary_torch(torch.celu), + singularity_fn=lambda x: x, test_directives=(), ) elementwise_unary_ops.append(celu_opinfo)