diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 0453d6368..bba6c637b 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -418,27 +418,32 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) end end -function tuple_expression(primal_sig_parts) +function _make_pullback_for_non_differentiable(::Val{N}) where {N} + Vararg{Any,N} # throw early for invalid `N`, must be nonnegative `Int` + function pullback_for_non_differentiable(::Any) + ntuple(Returns(NoTangent()), Val(N)) + end +end + +function tuple_length_expression(primal_sig_parts) has_vararg = _isvararg(primal_sig_parts[end]) return if !has_vararg num_primal_inputs = length(primal_sig_parts) - Expr(:tuple, ntuple(_ -> NoTangent(), num_primal_inputs)...) + :($num_primal_inputs) else num_primal_inputs = length(primal_sig_parts) - 1 # - vararg length_expr = :($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end]))))) - @strip_linenos :(ntuple(i -> NoTangent(), $length_expr)) + @strip_linenos :($length_expr) end end function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) esc_primal_sig_parts = map(esc, primal_sig_parts) - tup_expr = tuple_expression(primal_sig_parts) + tup_len_expr = tuple_length_expression(primal_sig_parts) primal_name = first(primal_invoke.args) pullback_expr = @strip_linenos quote - function $(esc(propagator_name(primal_name, :pullback)))(@nospecialize(_)) - return $(tup_expr) - end + _make_pullback_for_non_differentiable(Val{$(tup_len_expr)}()) end @gensym kwargs diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 43863a915..55f90724e 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -42,6 +42,47 @@ end @testset "rule_definition_tools.jl" begin @testset "@non_differentiable" begin + @testset "`_make_pullback_for_non_differentiable`" begin + f = ChainRulesCore._make_pullback_for_non_differentiable + @testset "throws on invalid input" begin + @test_throws Exception f(Val(0.0)) + @test_throws Exception f(Val(-1)) + end + @testset "identical objects" begin + for i ∈ 0:5 + v = Val(i) + @test f(v) === f(v) + end + end + @testset "correctness" begin + for i ∈ 0:5 + expected = ntuple((_ -> NoTangent()), i) + @test f(Val(i))(:arbitrary) === expected + end + end + @testset "dispatch" begin + for i ∈ 0:5 + pullback = f(Val(i)) + @test_throws MethodError pullback() + @test_throws MethodError pullback(1, 2) + end + end + end + + @testset "issue #678: identical pullback objects" begin + issue_678_f(::Any) = nothing + issue_678_g(::Any) = nothing + issue_678_h(::Any...) = nothing + @non_differentiable issue_678_f(::Any) + @non_differentiable issue_678_g(::Any) + @non_differentiable issue_678_h(::Any...) + @test ( + last(rrule(issue_678_f, 0.1)) === + last(rrule(issue_678_g, 0.2)) === + last(rrule(issue_678_h, 0.3)) + ) + end + @testset "two input one output function" begin nondiff_2_1(x, y) = fill(7.5, 100)[x + y] @non_differentiable nondiff_2_1(::Any, ::Any)