Skip to content

Commit

Permalink
use the prologue in check_vjp
Browse files Browse the repository at this point in the history
  • Loading branch information
beverlylytle committed Nov 12, 2024
1 parent f7b2e15 commit d14bfe4
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 9 deletions.
7 changes: 1 addition & 6 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,12 +1648,7 @@ def celu_sample_generator(op, device, dtype, requires_grad):
dtypes=(datatypes.floating,),
sample_input_generator=celu_sample_generator,
torch_reference=_elementwise_unary_torch(torch.celu),
test_directives=(
DecorateInfo(
custom_comparator(partial(assert_close, atol=1e-6, rtol=1e-6)),
"test_vjp_correctness",
),
),
test_directives=(),
)
elementwise_unary_ops.append(celu_opinfo)

Expand Down
4 changes: 1 addition & 3 deletions thunder/tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,7 @@ def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = Fals

u = tree_map(make, primals)

# dirty little trick for speed: skip the prologue
jf = executor.make_callable(f, disable_torch_autograd=True)
comp_f = thunder.compile_data(jf).get_computation_and_inputs(*primals)[0].computation_fn
comp_f = thunder.jit(f)

outs_p, J_u = numerical_jvp(comp_f)(primals, u)

Expand Down

0 comments on commit d14bfe4

Please sign in to comment.