Skip to content

Commit

Permalink
make @non_differentiable use identical pullbacks when possible
Browse files Browse the repository at this point in the history
Fixes #678
  • Loading branch information
nsajko committed May 31, 2024
1 parent fa530b9 commit 191eb47
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,14 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke)
end
end

struct NonDiffPullback{T<:Tuple{Vararg{NoTangent}}} <: Function
v::T
end

function (@nospecialize pb::NonDiffPullback)(@nospecialize ::Any)
return pb.v
end

function tuple_expression(primal_sig_parts)
has_vararg = _isvararg(primal_sig_parts[end])
return if !has_vararg
Expand All @@ -436,9 +444,7 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
tup_expr = tuple_expression(primal_sig_parts)
primal_name = first(primal_invoke.args)
pullback_expr = @strip_linenos quote
function $(esc(propagator_name(primal_name, :pullback)))(@nospecialize(_))
return $(tup_expr)
end
NonDiffPullback($(tup_expr))
end

@gensym kwargs
Expand Down
26 changes: 26 additions & 0 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,32 @@ end

@testset "rule_definition_tools.jl" begin
@testset "@non_differentiable" begin
@testset "`NonDiffPullback`" begin
NDP = ChainRulesCore.NonDiffPullback
for i in 0:5
tup = ntuple((_ -> NoTangent()), i)
ndp = NDP(tup)
@test ndp === @inferred NDP(tup)
@test tup === @inferred ndp(:arbitrary)
@test_throws MethodError ndp()
@test_throws MethodError ndp(1, 2)
end
end

@testset "issue #678: identical pullback objects" begin
issue_678_f(::Any) = nothing
issue_678_g(::Any) = nothing
issue_678_h(::Any...) = nothing
@non_differentiable issue_678_f(::Any)
@non_differentiable issue_678_g(::Any)
@non_differentiable issue_678_h(::Any...)
@test (
last(rrule(issue_678_f, 0.1)) ===
last(rrule(issue_678_g, 0.2)) ===
last(rrule(issue_678_h, 0.3))
)
end

@testset "two input one output function" begin
nondiff_2_1(x, y) = fill(7.5, 100)[x + y]
@non_differentiable nondiff_2_1(::Any, ::Any)
Expand Down

0 comments on commit 191eb47

Please sign in to comment.