diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 46796397d0..6f8dfcbf5e 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -1648,12 +1648,7 @@ def celu_sample_generator(op, device, dtype, requires_grad): dtypes=(datatypes.floating,), sample_input_generator=celu_sample_generator, torch_reference=_elementwise_unary_torch(torch.celu), - test_directives=( - DecorateInfo( - custom_comparator(partial(assert_close, atol=1e-6, rtol=1e-6)), - "test_vjp_correctness", - ), - ), + test_directives=(), ) elementwise_unary_ops.append(celu_opinfo) diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 25f42ccb87..8f64ea4493 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -296,9 +296,7 @@ def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = Fals u = tree_map(make, primals) - # dirty little trick for speed: skip the prologue - jf = executor.make_callable(f, disable_torch_autograd=True) - comp_f = thunder.compile_data(jf).get_computation_and_inputs(*primals)[0].computation_fn + comp_f = thunder.jit(f) outs_p, J_u = numerical_jvp(comp_f)(primals, u) @@ -306,7 +304,7 @@ def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = Fals v = tree_map(make, outs_p) if set_compile_data: - with thunder.core.compile_data.compile_data_and_stats(thunder.compile_data(jf), None): + with thunder.core.compile_data.compile_data_and_stats(thunder.compile_data(comp_f), None): initial_trace_vjp_f = thunder.trace()(vjp(f), primals, v) else: initial_trace_vjp_f = thunder.trace()(vjp(f), primals, v)