Update documentation #681

Merged
merged 7 commits on Sep 3, 2024
851 changes: 469 additions & 382 deletions docs/Manifest.toml

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/src/FAQ.md
@@ -63,7 +63,7 @@ For example, in `access(xs, n) = xs[n]`, the derivative of `access` with respect
When no custom `frule` or `rrule` exists, if you try to call one of those, it will return `nothing` by default.
As a result, you may encounter errors like

-```julia
+```plain
MethodError: no method matching iterate(::Nothing)
```
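For instance, here is a minimal sketch of how that error arises, using a hypothetical `myfun` with no rules defined: destructuring the `nothing` returned by `rrule` is what ends up calling `iterate`.

```julia
using ChainRulesCore

myfun(x) = x^2                # hypothetical function with no frule/rrule defined

res = rrule(myfun, 1.0)       # the generic fallback returns nothing
res === nothing               # true

# y, pb = rrule(myfun, 1.0)   # would throw: MethodError: no method matching iterate(::Nothing)
```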

2 changes: 1 addition & 1 deletion docs/src/ad_author/opt_out.md
@@ -7,7 +7,7 @@ We provide two ways to know that a rule has been opted out of.
`@opt_out` defines a `frule` or `rrule` matching the signature that returns `nothing`.

If you are in a position to generate code in response to values returned by function calls, then you can do something like:
-```@julia
+```julia
res = rrule(f, xs)
if res === nothing
    y, pullback = perform_ad_via_decomposition(f, xs) # do AD without hitting the rrule
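else
    y, pullback = res  # a rule exists, so use it (assumed continuation; the diff folds these lines)
end
```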
74 changes: 37 additions & 37 deletions docs/src/design/changing_the_primal.md
@@ -63,7 +63,7 @@ What about using `sincos`?
```@raw html
<details open><summary>Example for `sin`</summary>
```
-```julia
+```julia-repl
julia> using BenchmarkTools

julia> @btime sin(x) setup=(x=rand());
@@ -76,7 +76,7 @@ julia> 3.838 + 4.795
8.633
```
vs computing both together:
-```julia
+```julia-repl
julia> @btime sincos(x) setup=(x=rand());
6.028 ns (0 allocations: 0 bytes)
```
@@ -96,7 +96,7 @@ So we can save time, if we can reuse that `exp(x)`.
<details open><summary>Example for the logistic sigmoid</summary>
```
If we have to compute them separately:
-```julia
+```julia-repl
julia> @btime 1/(1+exp(x)) setup=(x=rand());
5.622 ns (0 allocations: 0 bytes)

@@ -108,7 +108,7 @@ julia> 5.622 + 6.036
```

vs reusing `exp(x)`:
-```julia
+```julia-repl
julia> @btime exp(x) setup=(x=rand());
5.367 ns (0 allocations: 0 bytes)

@@ -148,8 +148,8 @@ x̄ = pullback_at(f, x, y, ȳ, intermediates)
```
```julia
function augmented_primal(::typeof(sin), x)
-y, cx = sincos(x)
-return y, (; cx=cx) # use a NamedTuple for the intermediates
+    y, cx = sincos(x)
+    return y, (; cx=cx) # use a NamedTuple for the intermediates
end

pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx
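
# Usage sketch (illustrative addition; `x` and `ȳ` stand for assumed values):
#   y, intermediates = augmented_primal(sin, x)
#   x̄ = pullback_at(sin, x, y, ȳ, intermediates)   # == ȳ * cos(x)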
@@ -163,9 +163,9 @@ pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx
```
```julia
function augmented_primal(::typeof(σ), x)
-ex = exp(x)
-y = ex / (1 + ex)
-return y, (; ex=ex) # use a NamedTuple for the intermediates
+    ex = exp(x)
+    y = ex / (1 + ex)
+    return y, (; ex=ex) # use a NamedTuple for the intermediates
end

pullback_at(::typeof(σ), x, y, ȳ, intermediates) = ȳ * y / (1 + intermediates.ex)
@@ -189,8 +189,8 @@ And storing all these things on the tape — inputs, outputs, sensitivities, int
What if we generalized the idea of the `intermediate` named tuple, and had `augmented_primal` return a struct that just held anything we might want to put on the tape?
```julia
struct PullbackMemory{P, S}
-primal_function::P
-state::S
+    primal_function::P
+    state::S
end
# convenience constructor:
PullbackMemory(primal_function; state...) = PullbackMemory(primal_function, state)
@@ -211,8 +211,8 @@ which is much cleaner.
```
```julia
function augmented_primal(::typeof(sin), x)
-y, cx = sincos(x)
-return y, PullbackMemory(sin; cx=cx)
+    y, cx = sincos(x)
+    return y, PullbackMemory(sin; cx=cx)
end

pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx
@@ -226,9 +226,9 @@ pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx
```
```julia
function augmented_primal(::typeof(σ), x)
-ex = exp(x)
-y = ex / (1 + ex)
-return y, PullbackMemory(σ; y=y, ex=ex)
+    ex = exp(x)
+    y = ex / (1 + ex)
+    return y, PullbackMemory(σ; y=y, ex=ex)
end

pullback_at(pb::PullbackMemory{typeof(σ)}, ȳ) = ȳ * pb.y / (1 + pb.ex)
@@ -256,8 +256,8 @@ x̄ = pb(ȳ)
```
```julia
function augmented_primal(::typeof(sin), x)
-y, cx = sincos(x)
-return y, PullbackMemory(sin; cx=cx)
+    y, cx = sincos(x)
+    return y, PullbackMemory(sin; cx=cx)
end
(pb::PullbackMemory{typeof(sin)})(ȳ) = ȳ * pb.cx
```
@@ -271,9 +271,9 @@ end
```
```julia
function augmented_primal(::typeof(σ), x)
-ex = exp(x)
-y = ex / (1 + ex)
-return y, PullbackMemory(σ; y=y, ex=ex)
+    ex = exp(x)
+    y = ex / (1 + ex)
+    return y, PullbackMemory(σ; y=y, ex=ex)
end

(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y / (1 + pb.ex)
@@ -295,16 +295,16 @@ Let's go back and think about the changes we would have to make to go from our orig
To rewrite that original formulation in the new pullback form we have:
```julia
function augmented_primal(::typeof(sin), x)
-y = sin(x)
-return y, PullbackMemory(sin; x=x)
+    y = sin(x)
+    return y, PullbackMemory(sin; x=x)
end
(pb::PullbackMemory)(ȳ) = ȳ * cos(pb.x)
```
To go from that to:
```julia
function augmented_primal(::typeof(sin), x)
-y, cx = sincos(x)
-return y, PullbackMemory(sin; cx=cx)
+    y, cx = sincos(x)
+    return y, PullbackMemory(sin; cx=cx)
end
(pb::PullbackMemory)(ȳ) = ȳ * pb.cx
```
@@ -317,17 +317,17 @@ end
```
```julia
function augmented_primal(::typeof(σ), x)
-y = σ(x)
-return y, PullbackMemory(σ; y=y, x=x)
+    y = σ(x)
+    return y, PullbackMemory(σ; y=y, x=x)
end
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y * σ(-pb.x)
```
to get to:
```julia
function augmented_primal(::typeof(σ), x)
-ex = exp(x)
-y = ex/(1 + ex)
-return y, PullbackMemory(σ; y=y, ex=ex)
+    ex = exp(x)
+    y = ex/(1 + ex)
+    return y, PullbackMemory(σ; y=y, ex=ex)
end
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y/(1 + pb.ex)
```
@@ -356,9 +356,9 @@ Replacing `PullbackMemory` with a closure that works the same way lets us avoid
```
```julia
function augmented_primal(::typeof(sin), x)
-y, cx = sincos(x)
-pb = ȳ -> cx * ȳ # pullback closure. closes over `cx`
-return y, pb
+    y, cx = sincos(x)
+    pb = ȳ -> cx * ȳ # pullback closure. closes over `cx`
+    return y, pb
end
```
```@raw html
@@ -370,10 +370,10 @@ end
```
```julia
function augmented_primal(::typeof(σ), x)
-ex = exp(x)
-y = ex / (1 + ex)
-pb = ȳ -> ȳ * y / (1 + ex) # pullback closure. closes over `y` and `ex`
-return y, pb
+    ex = exp(x)
+    y = ex / (1 + ex)
+    pb = ȳ -> ȳ * y / (1 + ex) # pullback closure. closes over `y` and `ex`
+    return y, pb
end
```
```@raw html
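As a quick sanity check of the closure formulation, here is a usage sketch (assuming the `augmented_primal` definitions above are in scope):

```julia
y, pb = augmented_primal(sin, 0.5)
y ≈ sin(0.5)         # primal result is unchanged
pb(1.0) ≈ cos(0.5)   # the closure applies the stored cx = cos(0.5)
```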
6 changes: 3 additions & 3 deletions docs/src/design/many_tangents.md
@@ -45,7 +45,7 @@ Structural tangents are derived from the structure of the input.
Either automatically, as part of the AD, or manually, as part of a custom rule.

Consider the structure of `DateTime`:
-```julia
+```julia-repl
julia> dump(now())
DateTime
instant: UTInstant{Millisecond}
@@ -83,15 +83,15 @@ Where there is no natural tangent type for the outermost type but there is for s

Consider if we had a representation of a country's GDP as output by some continuous-time model like a Gaussian Process, where that representation is a sequence of `TimeSample`s
structured as follows:
-```julia
+```julia-repl
julia> struct TimeSample
time::DateTime
value::Float64
end
```

We can look at its structure:
-```julia
+```julia-repl
julia> dump(TimeSample(now(), 2.6e9))
TimeSample
time: DateTime
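To make the idea concrete, here is a sketch of a structural tangent for `TimeSample` using `ChainRulesCore.Tangent`; the field values are illustrative assumptions:

```julia
using ChainRulesCore, Dates

struct TimeSample
    time::DateTime
    value::Float64
end

# A structural tangent mirrors the primal's fields; here only `value` is perturbed:
t̄ = Tangent{TimeSample}(; time=NoTangent(), value=1.0)
t̄.value  # 1.0
```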
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -191,6 +191,7 @@ end
# output

```

```jldoctest index
#### Find dfoo/dx via rrules
#### First the forward pass, gathering up the pullbacks
2 changes: 1 addition & 1 deletion docs/src/rule_author/converting_zygoterules.md
@@ -1,4 +1,4 @@
-# Converting ZygoteRules.@adjoint to `rrule`s
+# Converting `ZygoteRules.@adjoint` to `rrule`s

[ZygoteRules.jl](https://github.com/FluxML/ZygoteRules.jl) is a legacy package similar to ChainRulesCore but supporting [Zygote.jl](https://github.com/FluxML/Zygote.jl) only.

3 changes: 2 additions & 1 deletion docs/src/rule_author/example.md
@@ -38,8 +38,9 @@ end
```

We can check this rule against a finite-differences approach using [`ChainRulesTestUtils`](https://github.com/JuliaDiff/ChainRulesTestUtils.jl):
-```julia
+```julia-repl
julia> using ChainRulesTestUtils

julia> test_rrule(foo_mul, Foo(rand(3, 3), 3.0), rand(3, 3))
Test Summary: | Pass Total
test_rrule: foo_mul on Foo{Float64},Matrix{Float64} | 10 10
23 changes: 13 additions & 10 deletions docs/src/rule_author/which_functions_need_rules.md
@@ -34,8 +34,9 @@ function addone(a::AbstractArray)
end
```
complains that
-```julia
+```julia-repl
julia> using Zygote

julia> gradient(addone, a)
ERROR: Mutating arrays is not supported
```
@@ -50,7 +51,7 @@ function ChainRules.rrule(::typeof(addone), a)
end
```
the gradient can be evaluated:
-```julia
+```julia-repl
julia> gradient(addone, a)
([1.0, 1.0, 1.0],)
```
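For reference, one rule consistent with the behaviour above, sketched under the assumption that `addone` computes `a .+ 1` (the hunk above folds the actual rule body), would be:

```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(addone), a)
    y = a .+ 1                              # non-mutating primal computation
    addone_pullback(ȳ) = (NoTangent(), ȳ)   # a .+ 1 is elementwise identity in `a`
    return y, addone_pullback
end
```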
@@ -86,7 +87,7 @@ function exception(x)
end
```
does not work
-```julia
+```julia-repl
julia> gradient(exception, 3.0)
ERROR: Compiling Tuple{typeof(exception),Int64}: try/catch is not supported.
```
Expand All @@ -101,7 +102,7 @@ function ChainRulesCore.rrule(::typeof(exception), x)
end
```

-```julia
+```julia-repl
julia> gradient(exception, 3.0)
(6.0,)
```
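Sketched end to end, such a rule could look like the following, under the assumption that the docs' `exception(x)` computes `x^2` inside the `try` (consistent with the gradient `6.0` at `x = 3.0`):

```julia
using ChainRulesCore

function exception(x)
    try
        return x^2
    catch
        rethrow()
    end
end

function ChainRulesCore.rrule(::typeof(exception), x)
    y = exception(x)
    exception_pullback(ȳ) = (NoTangent(), 2x * ȳ)  # d(x^2)/dx = 2x
    return y, exception_pullback
end
```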
@@ -123,9 +124,11 @@ function mse(y, ŷ)
end
```
takes a lot longer to AD through
-```julia
-julia> y = rand(30)
-julia> ŷ = rand(30)
+```julia-repl
+julia> y = rand(30);
+
+julia> ŷ = rand(30);
+
julia> @btime gradient(mse, $y, $ŷ)
38.180 μs (993 allocations: 65.00 KiB)
```
@@ -142,7 +145,7 @@ function ChainRules.rrule(::typeof(mse), x, x̂)
end
```
which is much faster
-```julia
+```julia-repl
julia> @btime gradient(mse, $y, $ŷ)
143.697 ns (2 allocations: 672 bytes)
```
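Filled out, such a rule might look like this sketch, under the assumption that `mse` is the usual mean of squared differences (the diff folds the actual body):

```julia
using ChainRulesCore

mse(y, ŷ) = sum(abs2, y .- ŷ) / length(y)   # assumed definition

function ChainRulesCore.rrule(::typeof(mse), y, ŷ)
    Δ = y .- ŷ
    n = length(y)
    function mse_pullback(v̄)
        ∂y = (2v̄ / n) .* Δ    # ∂mse/∂y = 2(y - ŷ)/n
        return NoTangent(), ∂y, -∂y
    end
    return sum(abs2, Δ) / n, mse_pullback
end
```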
@@ -159,7 +162,7 @@ function sum3(array)
return x+y+z
end
```
-```julia
+```julia-repl
julia> @btime gradient(sum3, rand(30))
424.510 ns (9 allocations: 2.06 KiB)
```
@@ -176,7 +179,7 @@ function ChainRulesCore.rrule(::typeof(sum3), a)
end
```
turns out to be significantly faster
-```julia
+```julia-repl
julia> @btime gradient(sum3, rand(30))
192.818 ns (3 allocations: 784 bytes)
```
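A rule of the right shape, assuming `sum3` ultimately adds up every element of the array (only its `return x+y+z` line is visible above), might be:

```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(sum3), a)
    y = sum3(a)
    # if every element contributes exactly once to the sum, the pullback is uniform:
    sum3_pullback(ȳ) = (NoTangent(), fill(ȳ, size(a)))
    return y, sum3_pullback
end
```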
6 changes: 3 additions & 3 deletions docs/src/rule_author/writing_good_rules.md
@@ -110,7 +110,7 @@ Because `typeof(Bar)` is `DataType`, using this to define an `rrule`/`frule` wil

You can check which to use with `Core.Typeof`:

-```julia
+```julia-repl
julia> function foo end
foo (generic function with 0 methods)

@@ -254,7 +254,7 @@ function ChainRulesCore.rrule(::typeof(double_it), x)
end
```
Ends up inferring a return type of `Any`
-```julia
+```julia-repl
julia> _, pullback = rrule(double_it, [2.0, 3.0])
([4.0, 6.0], var"#double_it_pullback#8"(Core.Box(var"#double_it_pullback#8"(#= circular reference @-2 =#))))

@@ -289,7 +289,7 @@ function ChainRulesCore.rrule(::typeof(double_it), x)
end
```
This infers just fine:
-```julia
+```julia-repl
julia> _, pullback = rrule(double_it, [2.0, 3.0])
([4.0, 6.0], _double_it_pullback)

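One lightweight way to guard against this kind of inference regression is `Test.@inferred`; a sketch, assuming the `double_it` rule above is defined:

```julia
using ChainRulesCore, Test

# Throws at test time if the call does not infer to a concrete return type:
@inferred rrule(double_it, [2.0, 3.0])
```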