diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 6e683b09df..2f389a480f 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -266,7 +266,7 @@ def _dot(x, y): return sum([_tensor_dot(a, b) for a, b in zip(x, y)]) -def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = False): +def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = False, prologue_required: bool = False): """Check that the vector-Jacobian product of a function is correct. Args: @@ -296,7 +296,13 @@ def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = Fals u = tree_map(make, primals) - comp_f = thunder.jit(f, disable_torch_autograd=True) + # dirty little trick for speed: skip the prologue, however, the prologue is required when + # there are non-differentiable kwargs + jf = executor.make_callable(f, disable_torch_autograd=True) + if prologue_required: + comp_f = thunder.jit(f, disable_torch_autograd=True) + else: + comp_f = thunder.compile_data(jf).get_computation_and_inputs(*primals)[0].computation_fn outs_p, J_u = numerical_jvp(comp_f)(primals, u) @@ -304,7 +310,7 @@ def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = Fals v = tree_map(make, outs_p) if set_compile_data: - with thunder.core.compile_data.compile_data_and_stats(thunder.compile_data(comp_f), None): + with thunder.core.compile_data.compile_data_and_stats(thunder.compile_data(jf), None): initial_trace_vjp_f = thunder.trace()(vjp(f), primals, v) else: initial_trace_vjp_f = thunder.trace()(vjp(f), primals, v) @@ -364,8 +370,15 @@ def wrapper(*differentiable_args): return wrapper, filtered_args -def snippet_vjp_correctness(func, args, comp, executor, set_compile_data): - check_vjp(func, *args, comp=comp, executor=executor, set_compile_data=set_compile_data) +def snippet_vjp_correctness(func, args, comp, executor, set_compile_data, prologue_required): + check_vjp( + func, + *args, + comp=comp, + executor=executor, + set_compile_data=set_compile_data, + prologue_required=prologue_required, + ) # TODO Use the given comparator @@ -408,6 +421,7 @@ def test_vjp_correctness(op, device, dtype, executor, comp): comp, executor, "adaptive_avg_pool2d" in op.name, + len(sample.kwargs) != 0, ) if result is not None: return result