From b72df42c543e57315c136862b93226f3185ea8e7 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 5 Dec 2023 17:08:34 +1300 Subject: [PATCH] =?UTF-8?q?add=20illustration=20of=20=E2=88=82self=20in=20?= =?UTF-8?q?the=20maths/propagators=20section?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/src/maths/propagators.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/src/maths/propagators.md b/docs/src/maths/propagators.md index aba531f55..acfc43a44 100644 --- a/docs/src/maths/propagators.md +++ b/docs/src/maths/propagators.md @@ -179,6 +179,24 @@ So every `pushforward` takes in an extra argument, which is ignored unless the o It is common to write `function foo_pushforward(_, Δargs...)` in the case when `foo` does not have fields. Similarly every `pullback` returns an extra `∂self`, which for things without fields is `NoTangent()`, indicating there are no fields within the function itself. +Here's an example showing how to define `∂self` in an `rrule` when the primal function has +internal fields (implicit arguments): + +```julia +struct Multiplier{T} + x::T +end +(m::Multiplier)(y) = m.x * y + +function ChainRulesCore.rrule(m::Multiplier, y) + product = m(y) + function pullback(Δproduct) + ∂self = Tangent{typeof(m)}(; x = Δproduct * y') + ∂y = m.x' * Δproduct + return ∂self, ∂y + return product, pullback +end +``` ### Pushforward / Pullback summary