Bump compats and update tutorials for Optimization v4 #950

Open · wants to merge 2 commits into base: master
8 changes: 4 additions & 4 deletions Project.toml
@@ -34,7 +34,7 @@ Boltz = "1"
ChainRulesCore = "1"
ComponentArrays = "0.15.17"
ConcreteStructs = "0.2"
DataInterpolations = "5, 6"
Member: don't remove 5

Member Author: That causes conflict too, with Symbolics

Member Author: Just saw your comment on the compat PR in Boltz, I'll find a workaround

Member: Why would 5 be needed?

Member: Fixed
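For anyone reproducing the discussion above, one way to check whether the wider bound still resolves is a quick local experiment (a sketch, not part of this PR; assumes a checked-out copy of the package and Pkg ≥ 1.8):

```julia
# Hypothetical local check: widen the compat entry and let the resolver run.
# If e.g. Symbolics pins DataInterpolations away from 5, `resolve` will error.
using Pkg
Pkg.activate(".")                          # the package environment under edit
Pkg.compat("DataInterpolations", "5, 6")   # set the compat entry (Pkg ≥ 1.8 API)
Pkg.resolve()                              # fails if no feasible version set exists
```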

DataInterpolations = "6.4"
DelayDiffEq = "5.47.3"
DiffEqCallbacks = "3.6.2"
Distances = "0.10.11"
@@ -54,9 +54,9 @@ LuxLib = "1.2"
NNlib = "0.9.22"
OneHotArrays = "0.2.5"
Optimisers = "0.3"
Optimization = "3.25.0"
OptimizationOptimJL = "0.3.0"
OptimizationOptimisers = "0.2.1"
Optimization = "4"
OptimizationOptimJL = "0.4"
OptimizationOptimisers = "0.3"
OrdinaryDiffEq = "6.76.0"
Printf = "1.10"
Random = "1.10"
8 changes: 4 additions & 4 deletions docs/Project.toml
@@ -57,10 +57,10 @@ MLUtils = "0.4"
NNlib = "0.9"
OneHotArrays = "0.2"
Optimisers = "0.3"
Optimization = "3.9"
OptimizationOptimJL = "0.2, 0.3"
OptimizationOptimisers = "0.2"
OptimizationPolyalgorithms = "0.2"
Optimization = "4"
OptimizationOptimJL = "0.4"
OptimizationOptimisers = "0.3"
OptimizationPolyalgorithms = "0.3"
OrdinaryDiffEq = "6.31"
Plots = "1.36"
Printf = "1"
34 changes: 17 additions & 17 deletions docs/src/examples/augmented_neural_ode.md
@@ -69,13 +69,13 @@ function plot_contour(model, ps, st, npoints = 300)
return contour(x, y, sol; fill = true, linewidth = 0.0)
end

loss_node(model, x, y, ps, st) = mean((first(model(x, ps, st)) .- y) .^ 2)
loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2)

dataloader = concentric_sphere(
2, (0.0f0, 2.0f0), (3.0f0, 4.0f0), 2000, 2000; batch_size = 256)

iter = 0
cb = function (ps, l)
cb = function (state, l)
global iter
iter += 1
if iter % 10 == 0
@@ -87,15 +87,15 @@ end
model, ps, st = construct_model(1, 2, 64, 0)
opt = OptimizationOptimisers.Adam(0.005)

loss_node(model, dataloader.data[1], dataloader.data[2], ps, st)
loss_node(model, (dataloader.data[1], dataloader.data[2]), ps, st)

println("Training Neural ODE")

optfunc = OptimizationFunction(
(x, p, data, target) -> loss_node(model, data, target, x, st),
(x, data) -> loss_node(model, data, x, st),
Optimization.AutoZygote())
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
res = solve(optprob, opt; callback = cb, epochs = 1000)

plt_node = plot_contour(model, res.u, st)

@@ -106,10 +106,10 @@ println()
println("Training Augmented Neural ODE")

optfunc = OptimizationFunction(
(x, p, data, target) -> loss_node(model, data, target, x, st),
(x, data) -> loss_node(model, data, x, st),
Optimization.AutoZygote())
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
res = solve(optprob, opt; callback = cb, epochs = 1000)

plot_contour(model, res.u, st)
```
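For reference, the calling convention these tutorial changes adopt is consolidated below. This is a sketch of the Optimization.jl v4 data-in-problem pattern as used in the diff above, reusing the tutorial's own names (`loss_node`, `dataloader`, `cdev`, `gdev`); it is not an additional tutorial step.

```julia
# Sketch of the v4 pattern: the dataloader is attached to the problem, `solve`
# iterates it for `epochs` passes, the objective receives (params, batch), and
# the callback receives (state, loss) with the current parameters in `state.u`.
optfunc = OptimizationFunction(
    (u, batch) -> loss_node(model, batch, u, st), Optimization.AutoZygote())
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
res = solve(optprob, OptimizationOptimisers.Adam(0.005);
    callback = (state, l) -> (println(l); false),   # return true to stop early
    epochs = 1000)
```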
@@ -229,7 +229,7 @@ We use the L2 distance between the model prediction `model(x)` and the actual pr
optimization objective.

```@example augneuralode
loss_node(model, x, y, ps, st) = mean((first(model(x, ps, st)) .- y) .^ 2)
loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2)
```

#### Dataset
@@ -248,7 +248,7 @@ Additionally, we define a callback function which displays the total loss at spe

```@example augneuralode
iter = 0
cb = function (ps, l)
cb = function (state, l)
global iter
iter += 1
if iter % 10 == 0
@@ -276,10 +276,10 @@ for `20` epochs.
model, ps, st = construct_model(1, 2, 64, 0)

optfunc = OptimizationFunction(
(x, p, data, target) -> loss_node(model, data, target, x, st),
(x, data) -> loss_node(model, data, x, st),
Optimization.AutoZygote())
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
res = solve(optprob, opt; callback = cb, epochs = 1000)

plot_contour(model, res.u, st)
```
@@ -297,10 +297,10 @@ a function which can be expressed by the neural ode. For more details and proofs
model, ps, st = construct_model(1, 2, 64, 1)

optfunc = OptimizationFunction(
(x, p, data, target) -> loss_node(model, data, target, x, st),
(x, data) -> loss_node(model, data, x, st),
Optimization.AutoZygote())
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
res = solve(optprob, opt; callback = cb, epochs = 1000)

plot_contour(model, res.u, st)
```
25 changes: 11 additions & 14 deletions docs/src/examples/hamiltonian_nn.md
@@ -75,7 +75,7 @@ The HNN predicts the gradients ``(\dot q, \dot p)`` given ``(q, p)``. Hence, we

```@example hamiltonian
using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
ComponentArrays, Optimization, OptimizationOptimisers, IterTools
ComponentArrays, Optimization, OptimizationOptimisers, MLUtils

t = range(0.0f0, 1.0f0; length = 1024)
π_32 = Float32(π)
@@ -87,12 +87,8 @@ dpdt = -2π_32 .* q_t
data = cat(q_t, p_t; dims = 1)
target = cat(dqdt, dpdt; dims = 1)
B = 256
NEPOCHS = 100
dataloader = ncycle(
((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))))
for i in 1:(size(data, 2) ÷ B)),
NEPOCHS)
NEPOCHS = 1000
dataloader = DataLoader((data, target); batchsize = B)
```

### Training the HamiltonianNN
@@ -103,24 +99,25 @@ We parameterize the with a small MultiLayered Perceptron. HNNs are trained by o
hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote())
ps, st = Lux.setup(Xoshiro(0), hnn)
ps_c = ps |> ComponentArray
hnn_stateful = StatefulLuxLayer{true}(hnn, ps_c, st)

opt = OptimizationOptimisers.Adam(0.01f0)

function loss_function(ps, data, target)
pred, st_ = hnn(data, ps, st)
function loss_function(ps, databatch)
(data, target) = databatch
pred = hnn_stateful(data, ps)
return mean(abs2, pred .- target), pred
end

function callback(ps, loss, pred)
function callback(state, loss)
println("[Hamiltonian NN] Loss: ", loss)
return false
end

opt_func = OptimizationFunction(
(ps, _, data, target) -> loss_function(ps, data, target), Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps_c)
opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps_c, dataloader)

res = solve(opt_prob, opt, dataloader; callback)
res = solve(opt_prob, opt; callback, epochs = NEPOCHS)

ps_trained = res.u
```
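The hunk above also routes calls through a `StatefulLuxLayer`. As a rough illustration of why (a standalone sketch with hypothetical layer sizes, not taken from the tutorial), the wrapper captures the layer state at construction so the optimization objective only has to thread the parameters:

```julia
# Sketch: StatefulLuxLayer holds `st` internally, so the wrapped layer can be
# called with just (x, ps) inside an Optimization.jl objective.
using Lux, Random

mlp = Chain(Dense(2 => 16, tanh), Dense(16 => 1))
ps, st = Lux.setup(Xoshiro(0), mlp)
smodel = StatefulLuxLayer{true}(mlp, ps, st)   # state captured at construction
y = smodel(rand(Float32, 2, 8), ps)            # no `st` argument, no state bookkeeping
```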
14 changes: 7 additions & 7 deletions docs/src/examples/mnist_conv_neural_ode.md
@@ -89,20 +89,20 @@ end
# burn in accuracy
accuracy(m, ((img, lab),), ps, st)

function loss_function(ps, x, y)
function loss_function(ps, data)
(x, y) = data
pred, _ = m(x, ps, st)
return logitcrossentropy(pred, y), pred
return logitcrossentropy(pred, y)
end

# burn in loss
loss_function(ps, img, lab)
loss_function(ps, (img, lab))

opt = OptimizationOptimisers.Adam(0.005)
iter = 0

opt_func = OptimizationFunction(
(ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps);
opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps, dataloader);

function callback(ps, l, pred)
global iter += 1
@@ -112,7 +112,7 @@ end
end

# Train the NN-ODE and monitor the loss and weights.
res = Optimization.solve(opt_prob, opt, dataloader; maxiters = 5, callback)
res = Optimization.solve(opt_prob, opt; maxiters = 5, callback)
acc = accuracy(m, dataloader, res.u, st)
acc # hide
```
34 changes: 17 additions & 17 deletions docs/src/examples/mnist_neural_ode.md
@@ -81,29 +81,29 @@ end

accuracy(m, ((x_train1, y_train1),), ps, st) # burn in accuracy

function loss_function(ps, x, y)
function loss_function(ps, data)
(x, y) = data
pred, st_ = m(x, ps, st)
return logitcrossentropy(pred, y), pred
return logitcrossentropy(pred, y)
end

loss_function(ps, x_train1, y_train1) # burn in loss

opt = OptimizationOptimisers.Adam(0.05)
iter = 0

opt_func = OptimizationFunction(
(ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps)
opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps, dataloader)

function callback(ps, l, pred)
function callback(state, l)
global iter += 1
iter % 10 == 0 &&
@info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))"
@info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))"
return false
end

# Train the NN-ODE and monitor the loss and weights.
res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5)
res = Optimization.solve(opt_prob, opt; callback, maxiters = 5)
accuracy(m, dataloader, res.u, st)
```

@@ -285,12 +285,13 @@ final output of our model. `logitcrossentropy` takes in the prediction from our
model `model(x)` and compares it to actual output `y`:

```@example mnist
function loss_function(ps, x, y)
function loss_function(ps, data)
(x, y) = data
pred, st_ = m(x, ps, st)
return logitcrossentropy(pred, y), pred
return logitcrossentropy(pred, y)
end

loss_function(ps, x_train1, y_train1) # burn in loss
loss_function(ps, (x_train1, y_train1)) # burn in loss
```

#### Optimizer
@@ -309,14 +310,13 @@ This callback function is used to print both the training and testing accuracy a
```@example mnist
iter = 0

opt_func = OptimizationFunction(
(ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps)
opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps, dataloader)

function callback(ps, l, pred)
function callback(state, l)
global iter += 1
iter % 10 == 0 &&
@info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))"
@info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))"
return false
end
```
@@ -329,6 +329,6 @@ for Neural ODE is given by `nn_ode.p`:

```@example mnist
# Train the NN-ODE and monitor the loss and weights.
res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5)
res = Optimization.solve(opt_prob, opt; callback, maxiters = 5)
accuracy(m, dataloader, res.u, st)
```
19 changes: 12 additions & 7 deletions docs/src/examples/multiple_shooting.md
@@ -48,6 +48,9 @@ ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps))
nn = Chain(x -> x .^ 3, Dense(2, 16, tanh), Dense(16, 2))
p_init, st = Lux.setup(rng, nn)

ps = ComponentArray(p_init)
pd, pax = getdata(ps), getaxes(ps)

neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps)
prob_node = ODEProblem((u, p, t) -> nn(u, p, st)[1], u0, tspan, ComponentArray(p_init))

@@ -62,14 +65,13 @@ end

anim = Plots.Animation()
iter = 0
callback = function (p, l, preds; doplot = true)
function callback(state, l; doplot = true, prob_node = prob_node)
display(l)
global iter
iter += 1
if doplot && iter % 1 == 0
# plot the original data
plt = scatter(tsteps, ode_data[1, :]; label = "Data")

# plot the different predictions for individual shoot
plot_multiple_shoot(plt, preds, group_size)
Member: I don't see how this is going to work if we don't have the predictions?
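One way this can work, and the pattern the updated tutorial appears to use, is to stash the predictions computed inside the loss in a variable the callback reads, since the v4 callback no longer receives them. A consolidated sketch using the tutorial's names:

```julia
# Sketch: the loss stores its latest predictions in the global `preds`,
# and the plotting callback reuses them.
preds = nothing

function loss_multiple_shooting(p)
    ps = ComponentArray(p, pax)
    loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
        Tsit5(), group_size; continuity_term)
    global preds = currpred                          # saved for the callback
    return loss
end

callback = function (state, l; doplot = true)
    if doplot
        plt = scatter(tsteps, ode_data[1, :]; label = "Data")
        plot_multiple_shoot(plt, preds, group_size)  # uses the stashed predictions
        display(plt)
    end
    return false
end
```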


@@ -83,23 +85,26 @@ end
group_size = 3
continuity_term = 200

l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
Tsit5(), group_size; continuity_term)

function loss_function(data, pred)
return sum(abs2, data - pred)
end

ps = ComponentArray(p_init)
pd, pax = getdata(ps), getaxes(ps)

function loss_multiple_shooting(p)
ps = ComponentArray(p, pax)
return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,

loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
Tsit5(), group_size; continuity_term)
global preds = currpred
return loss
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pd)
res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback)
res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 5000)
gif(anim, "multiple_shooting.gif"; fps = 15)
```
