Skip to content

Commit b8857c0

Browse files
Merge branch 'master' of https://github.com/SciML/Optimization.jl into refactor-lbfgs
2 parents 5bea599 + 2a90ec0 commit b8857c0

File tree

3 files changed

+34
-9
lines changed

3 files changed

+34
-9
lines changed

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,17 @@ function SciMLBase.__solve(cache::OptimizationCache{
139139
error("Use OptimizationFunction to pass the derivatives or automatically generate them with one of the autodiff backends")
140140

141141
function _cb(trace)
142-
metadata = decompose_trace(trace).metadata
142+
trace_state = decompose_trace(trace)
143+
metadata = trace_state.metadata
143144
θ = metadata[cache.opt isa Optim.NelderMead ? "centroid" : "x"]
144-
opt_state = Optimization.OptimizationState(iter = trace.iteration,
145+
opt_state = Optimization.OptimizationState(iter = trace_state.iteration,
145146
u = θ,
146147
p = cache.p,
147-
objective = trace.value,
148+
objective = trace_state.value,
148149
grad = get(metadata, "g(x)", nothing),
149150
hess = get(metadata, "h(x)", nothing),
150151
original = trace)
151-
cb_call = cache.callback(opt_state, trace.value)
152+
cb_call = cache.callback(opt_state, trace_state.value)
152153
if !(cb_call isa Bool)
153154
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
154155
end
@@ -257,18 +258,19 @@ function SciMLBase.__solve(cache::OptimizationCache{
257258
local x, cur, state
258259

259260
function _cb(trace)
260-
metadata = decompose_trace(trace).metadata
261+
trace_state = decompose_trace(trace)
262+
metadata = trace_state.metadata
261263
θ = !(cache.opt isa Optim.SAMIN) && cache.opt.method == Optim.NelderMead() ?
262264
metadata["centroid"] :
263265
metadata["x"]
264-
opt_state = Optimization.OptimizationState(iter = trace.iteration,
266+
opt_state = Optimization.OptimizationState(iter = trace_state.iteration,
265267
u = θ,
266268
p = cache.p,
267-
objective = trace.value,
269+
objective = trace_state.value,
268270
grad = get(metadata, "g(x)", nothing),
269271
hess = get(metadata, "h(x)", nothing),
270272
original = trace)
271-
cb_call = cache.callback(opt_state, trace.value)
273+
cb_call = cache.callback(opt_state, trace_state.value)
272274
if !(cb_call isa Bool)
273275
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
274276
end

lib/OptimizationOptimJL/test/runtests.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,4 +213,27 @@ end
213213
sol = Optimization.solve!(cache)
214214
@test sol.u[2.0] atol=1e-3
215215
end
216+
217+
@testset "store_trace=true" begin
218+
# Test that store_trace=true works without throwing errors (issue #990)
219+
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
220+
x0 = zeros(2)
221+
_p = [1.0, 100.0]
222+
223+
# Test with NelderMead
224+
prob = OptimizationProblem(rosenbrock, x0, _p)
225+
sol = solve(prob, NelderMead(), store_trace = true)
226+
@test sol isa Any # just test it doesn't throw
227+
228+
# Test with Fminbox(NelderMead)
229+
optprob = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
230+
prob = OptimizationProblem(optprob, x0, _p, lb = [-1.0, -1.0], ub = [0.8, 0.8])
231+
sol = solve(prob, Optim.Fminbox(NelderMead()), store_trace = true)
232+
@test sol isa Any # just test it doesn't throw
233+
234+
# Test with BFGS
235+
prob = OptimizationProblem(optprob, x0, _p)
236+
sol = solve(prob, BFGS(), store_trace = true)
237+
@test sol isa Any # just test it doesn't throw
238+
end
216239
end

lib/OptimizationOptimisers/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimizationOptimisers"
22
uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
33
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com> and contributors"]
4-
version = "0.3.8"
4+
version = "0.3.9"
55

66
[deps]
77
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"

0 commit comments

Comments
 (0)