Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit 9f59f80

Browse files
Merge pull request #100 from SciML/optjlintegration
Changes for getting SciML/Optimization.jl#789 passing
2 parents 8f0a067 + f0a527b commit 9f59f80

File tree

7 files changed

+185
-123
lines changed

7 files changed

+185
-123
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 29 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -204,13 +204,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
204204
end
205205

206206
if cons !== nothing && cons_j == true && f.cons_j === nothing
207-
if num_cons > length(x)
208-
seeds = Enzyme.onehot(x)
209-
Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
210-
else
211-
seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
212-
Jaccache = Tuple(zero(x) for i in 1:num_cons)
213-
end
207+
# if num_cons > length(x)
208+
seeds = Enzyme.onehot(x)
209+
Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
210+
# else
211+
# seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
212+
# Jaccache = Tuple(zero(x) for i in 1:num_cons)
213+
# end
214214

215215
y = zeros(eltype(x), num_cons)
216216

@@ -219,27 +219,26 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
219219
Enzyme.make_zero!(Jaccache[i])
220220
end
221221
Enzyme.make_zero!(y)
222-
if num_cons > length(θ)
223-
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
224-
BatchDuplicated(θ, seeds), Const(p))
225-
for i in eachindex(θ)
226-
if J isa Vector
227-
J[i] = Jaccache[i][1]
228-
else
229-
copyto!(@view(J[:, i]), Jaccache[i])
230-
end
231-
end
232-
else
233-
Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds),
234-
BatchDuplicated(θ, Jaccache), Const(p))
235-
for i in 1:num_cons
236-
if J isa Vector
237-
J .= Jaccache[1]
238-
else
239-
copyto!(@view(J[i, :]), Jaccache[i])
240-
end
222+
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
223+
BatchDuplicated(θ, seeds), Const(p))
224+
for i in eachindex(θ)
225+
if J isa Vector
226+
J[i] = Jaccache[i][1]
227+
else
228+
copyto!(@view(J[:, i]), Jaccache[i])
241229
end
242230
end
231+
# else
232+
# Enzyme.autodiff(Enzyme.Reverse, f.cons, BatchDuplicated(y, seeds),
233+
# BatchDuplicated(θ, Jaccache), Const(p))
234+
# for i in 1:num_cons
235+
# if J isa Vector
236+
# J .= Jaccache[1]
237+
# else
238+
# J[i, :] = Jaccache[i]
239+
# end
240+
# end
241+
# end
243242
end
244243
elseif cons_j == true && cons !== nothing
245244
cons_j! = (J, θ) -> f.cons_j(J, θ, p)
@@ -397,11 +396,11 @@ end
397396
function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
398397
cache::OptimizationBase.ReInitCache,
399398
adtype::AutoEnzyme,
400-
num_cons = 0)
399+
num_cons = 0; kwargs...)
401400
p = cache.p
402401
x = cache.u0
403402

404-
return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons)
403+
return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...)
405404
end
406405

407406
function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x,
@@ -676,11 +675,11 @@ end
676675
function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
677676
cache::OptimizationBase.ReInitCache,
678677
adtype::AutoEnzyme,
679-
num_cons = 0)
678+
num_cons = 0; kwargs...)
680679
p = cache.p
681680
x = cache.u0
682681

683-
return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons)
682+
return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...)
684683
end
685684

686685
end

ext/OptimizationMTKExt.jl

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ end
5757

5858
function OptimizationBase.instantiate_function(
5959
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
60-
adtype::AutoSparse{<:AutoSymbolics}, num_cons = 0,
60+
adtype::AutoSparse{<:AutoSymbolics}, num_cons = 0;
6161
g = false, h = false, hv = false, fg = false, fgh = false,
6262
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
6363
lag_h = false)
@@ -107,7 +107,7 @@ end
107107

108108
function OptimizationBase.instantiate_function(
109109
f::OptimizationFunction{true}, x, adtype::AutoSymbolics, p,
110-
num_cons = 0, g = false, h = false, hv = false, fg = false, fgh = false,
110+
num_cons = 0; g = false, h = false, hv = false, fg = false, fgh = false,
111111
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
112112
lag_h = false)
113113
p = isnothing(p) ? SciMLBase.NullParameters() : p
@@ -155,7 +155,7 @@ end
155155

156156
function OptimizationBase.instantiate_function(
157157
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
158-
adtype::AutoSymbolics, num_cons = 0,
158+
adtype::AutoSymbolics, num_cons = 0;
159159
g = false, h = false, hv = false, fg = false, fgh = false,
160160
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
161161
lag_h = false)

ext/OptimizationZygoteExt.jl

Lines changed: 24 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -220,7 +220,7 @@ function OptimizationBase.instantiate_function(
220220
if f.lag_h === nothing && cons !== nothing && lag_h == true
221221
lag_extras = prepare_hessian(
222222
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
223-
lag_hess_prototype = zeros(Bool, length(x), length(x))
223+
lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)
224224

225225
function lag_h!(H::AbstractMatrix, θ, σ, λ)
226226
if σ == zero(eltype(θ))
@@ -232,13 +232,11 @@ function OptimizationBase.instantiate_function(
232232
end
233233
end
234234

235-
function lag_h!(h, θ, σ, λ)
236-
H = eltype(θ).(lag_hess_prototype)
237-
hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras)
235+
function lag_h!(h::AbstractVector, θ, σ, λ)
236+
H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)
238237
k = 0
239-
rows, cols, _ = findnz(H)
240-
for (i, j) in zip(rows, cols)
241-
if i <= j
238+
for i in 1:length(θ)
239+
for j in 1:i
242240
k += 1
243241
h[k] = H[i, j]
244242
end
@@ -256,7 +254,7 @@ function OptimizationBase.instantiate_function(
256254
1:length(θ), 1:length(θ)])
257255
end
258256
end
259-
257+
260258
function lag_h!(h::AbstractVector, θ, σ, λ, p)
261259
global _p = p
262260
H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)
@@ -294,21 +292,20 @@ end
294292

295293
function OptimizationBase.instantiate_function(
296294
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
297-
adtype::ADTypes.AutoZygote, num_cons = 0;
298-
g = false, h = false, hv = false, fg = false, fgh = false,
299-
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false)
295+
adtype::ADTypes.AutoZygote, num_cons = 0; kwargs...)
300296
x = cache.u0
301297
p = cache.p
302298

303299
return OptimizationBase.instantiate_function(
304-
f, x, adtype, p, num_cons; g, h, hv, fg, fgh, cons_j, cons_vjp, cons_jvp, cons_h)
300+
f, x, adtype, p, num_cons; kwargs...)
305301
end
306302

307303
function OptimizationBase.instantiate_function(
308304
f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AutoZygote},
309305
p = SciMLBase.NullParameters(), num_cons = 0;
310306
g = false, h = false, hv = false, fg = false, fgh = false,
311-
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false)
307+
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
308+
lag_h = false)
312309
function _f(θ)
313310
return f.f(θ, p)[1]
314311
end
@@ -335,7 +332,7 @@ function OptimizationBase.instantiate_function(
335332
grad = nothing
336333
end
337334

338-
if fg == true && f.fg !== nothing
335+
if fg == true && f.fg === nothing
339336
if g == false
340337
extras_grad = prepare_gradient(_f, adtype.dense_ad, x)
341338
end
@@ -361,7 +358,7 @@ function OptimizationBase.instantiate_function(
361358

362359
hess_sparsity = f.hess_prototype
363360
hess_colors = f.hess_colorvec
364-
if f.hess === nothing
361+
if h == true && f.hess === nothing
365362
extras_hess = prepare_hessian(_f, soadtype, x) #placeholder logic, can be made much better
366363
function hess(res, θ)
367364
hessian!(_f, res, soadtype, θ, extras_hess)
@@ -384,7 +381,7 @@ function OptimizationBase.instantiate_function(
384381
hess = nothing
385382
end
386383

387-
if fgh == true && f.fgh !== nothing
384+
if fgh == true && f.fgh === nothing
388385
function fgh!(G, H, θ)
389386
(y, _, _) = value_derivative_and_second_derivative!(_f, G, H, θ, extras_hess)
390387
return y
@@ -406,7 +403,7 @@ function OptimizationBase.instantiate_function(
406403
fgh! = nothing
407404
end
408405

409-
if hv == true && f.hv !== nothing
406+
if hv == true && f.hv === nothing
410407
extras_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x)))
411408
function hv!(H, θ, v)
412409
hvp!(_f, H, soadtype.dense_ad, θ, v, extras_hvp)
@@ -443,7 +440,7 @@ function OptimizationBase.instantiate_function(
443440
θ = augvars[1:length(x)]
444441
σ = augvars[length(x) + 1]
445442
λ = augvars[(length(x) + 2):end]
446-
return σ * _f(θ) + dot(λ, cons(θ))
443+
return σ * _f(θ) + dot(λ, cons_oop(θ))
447444
end
448445
end
449446

@@ -466,7 +463,8 @@ function OptimizationBase.instantiate_function(
466463
end
467464

468465
if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
469-
extras_pullback = prepare_pullback(cons_oop, adtype, x)
466+
extras_pullback = prepare_pullback(
467+
cons_oop, adtype.dense_ad, x, ones(eltype(x), num_cons))
470468
function cons_vjp!(J, θ, v)
471469
pullback!(cons_oop, J, adtype.dense_ad, θ, v, extras_pullback)
472470
end
@@ -477,7 +475,8 @@ function OptimizationBase.instantiate_function(
477475
end
478476

479477
if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
480-
extras_pushforward = prepare_pushforward(cons_oop, adtype, x)
478+
extras_pushforward = prepare_pushforward(
479+
cons_oop, adtype.dense_ad, x, ones(eltype(x), length(x)))
481480
function cons_jvp!(J, θ, v)
482481
pushforward!(cons_oop, J, adtype.dense_ad, θ, v, extras_pushforward)
483482
end
@@ -510,10 +509,11 @@ function OptimizationBase.instantiate_function(
510509
end
511510

512511
lag_hess_prototype = f.lag_hess_prototype
513-
if cons !== nothing && cons_h == true && f.lag_h === nothing
512+
lag_hess_colors = f.lag_hess_colorvec
513+
if cons !== nothing && f.lag_h === nothing && lag_h == true
514514
lag_extras = prepare_hessian(
515515
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
516-
lag_hess_prototype = lag_extras.coloring_result.S[1:length(θ), 1:length(θ)]
516+
lag_hess_prototype = lag_extras.coloring_result.S[1:length(x), 1:length(x)]
517517
lag_hess_colors = lag_extras.coloring_result.color
518518

519519
function lag_h!(H::AbstractMatrix, θ, σ, λ)
@@ -587,14 +587,11 @@ end
587587

588588
function OptimizationBase.instantiate_function(
589589
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
590-
adtype::ADTypes.AutoSparse{<:AutoZygote}, num_cons = 0;
591-
g = false, h = false, hv = false, fg = false, fgh = false,
592-
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false)
590+
adtype::ADTypes.AutoSparse{<:AutoZygote}, num_cons = 0; kwargs...)
593591
x = cache.u0
594592
p = cache.p
595593

596-
return OptimizationBase.instantiate_function(
597-
f, x, adtype, p, num_cons; g, h, hv, fg, fgh, cons_j, cons_vjp, cons_jvp, cons_h)
594+
return OptimizationBase.instantiate_function(f, x, adtype, p, num_cons; kwargs...)
598595
end
599596

600597
end

src/OptimizationDIExt.jl

Lines changed: 10 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -104,7 +104,7 @@ function instantiate_function(
104104
hess = nothing
105105
end
106106

107-
if fgh == true && f.fgh !== nothing
107+
if fgh == true && f.fgh === nothing
108108
function fgh!(G, H, θ)
109109
(y, _, _) = value_derivative_and_second_derivative!(
110110
_f, G, H, soadtype, θ, extras_hess)
@@ -229,7 +229,7 @@ function instantiate_function(
229229
if cons !== nothing && lag_h == true && f.lag_h === nothing
230230
lag_extras = prepare_hessian(
231231
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
232-
lag_hess_prototype = zeros(Bool, length(x), length(x))
232+
lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)
233233

234234
function lag_h!(H::AbstractMatrix, θ, σ, λ)
235235
if σ == zero(eltype(θ))
@@ -263,7 +263,7 @@ function instantiate_function(
263263
1:length(θ), 1:length(θ)])
264264
end
265265
end
266-
266+
267267
function lag_h!(h::AbstractVector, θ, σ, λ, p)
268268
global _p = p
269269
H = hessian(lagrangian, soadtype, vcat(θ, [σ], λ), lag_extras)
@@ -301,16 +301,12 @@ end
301301

302302
function instantiate_function(
303303
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
304-
adtype::ADTypes.AbstractADType, num_cons = 0,
305-
g = false, h = false, hv = false, fg = false, fgh = false,
306-
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
307-
lag_h = false)
304+
adtype::ADTypes.AbstractADType, num_cons = 0;
305+
kwargs...)
308306
x = cache.u0
309307
p = cache.p
310308

311-
return instantiate_function(f, x, adtype, p, num_cons; g = g, h = h, hv = hv,
312-
fg = fg, fgh = fgh, cons_j = cons_j, cons_vjp = cons_vjp, cons_jvp = cons_jvp,
313-
cons_h = cons_h, lag_h = lag_h)
309+
return instantiate_function(f, x, adtype, p, num_cons; kwargs...)
314310
end
315311

316312
function instantiate_function(
@@ -392,7 +388,7 @@ function instantiate_function(
392388
hess = nothing
393389
end
394390

395-
if fgh == true && f.fgh !== nothing
391+
if fgh == true && f.fgh === nothing
396392
function fgh!(θ)
397393
(y, G, H) = value_derivative_and_second_derivative(_f, adtype, θ, extras_hess)
398394
return y, G, H
@@ -511,7 +507,7 @@ function instantiate_function(
511507
if cons !== nothing && lag_h == true && f.lag_h === nothing
512508
lag_extras = prepare_hessian(
513509
lagrangian, soadtype, vcat(x, [one(eltype(x))], ones(eltype(x), num_cons)))
514-
lag_hess_prototype = zeros(Bool, length(x), length(x))
510+
lag_hess_prototype = zeros(Bool, length(x) + num_cons + 1, length(x) + num_cons + 1)
515511

516512
function lag_h!(θ, σ, λ)
517513
if σ == zero(eltype(θ))
@@ -558,9 +554,9 @@ end
558554

559555
function instantiate_function(
560556
f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache,
561-
adtype::ADTypes.AbstractADType, num_cons = 0)
557+
adtype::ADTypes.AbstractADType, num_cons = 0; kwargs...)
562558
x = cache.u0
563559
p = cache.p
564560

565-
return instantiate_function(f, x, adtype, p, num_cons)
561+
return instantiate_function(f, x, adtype, p, num_cons; kwargs...)
566562
end

0 commit comments

Comments (0)