@@ -220,7 +220,7 @@ function OptimizationBase.instantiate_function(
220
220
if f. lag_h === nothing && cons != = nothing && lag_h == true
221
221
lag_extras = prepare_hessian (
222
222
lagrangian, soadtype, vcat (x, [one (eltype (x))], ones (eltype (x), num_cons)))
223
- lag_hess_prototype = zeros (Bool, length (x), length (x))
223
+ lag_hess_prototype = zeros (Bool, length (x) + num_cons + 1 , length (x) + num_cons + 1 )
224
224
225
225
function lag_h! (H:: AbstractMatrix , θ, σ, λ)
226
226
if σ == zero (eltype (θ))
@@ -232,13 +232,11 @@ function OptimizationBase.instantiate_function(
232
232
end
233
233
end
234
234
235
- function lag_h! (h, θ, σ, λ)
236
- H = eltype (θ).(lag_hess_prototype)
237
- hessian! (x -> lagrangian (x, σ, λ), H, soadtype, θ, lag_extras)
235
+ function lag_h! (h:: AbstractVector , θ, σ, λ)
236
+ H = hessian (lagrangian, soadtype, vcat (θ, [σ], λ), lag_extras)
238
237
k = 0
239
- rows, cols, _ = findnz (H)
240
- for (i, j) in zip (rows, cols)
241
- if i <= j
238
+ for i in 1 : length (θ)
239
+ for j in 1 : i
242
240
k += 1
243
241
h[k] = H[i, j]
244
242
end
@@ -256,7 +254,7 @@ function OptimizationBase.instantiate_function(
256
254
1 : length (θ), 1 : length (θ)])
257
255
end
258
256
end
259
-
257
+
260
258
function lag_h! (h:: AbstractVector , θ, σ, λ, p)
261
259
global _p = p
262
260
H = hessian (lagrangian, soadtype, vcat (θ, [σ], λ), lag_extras)
@@ -294,21 +292,20 @@ end
294
292
295
293
function OptimizationBase. instantiate_function (
296
294
f:: OptimizationFunction{true} , cache:: OptimizationBase.ReInitCache ,
297
- adtype:: ADTypes.AutoZygote , num_cons = 0 ;
298
- g = false , h = false , hv = false , fg = false , fgh = false ,
299
- cons_j = false , cons_vjp = false , cons_jvp = false , cons_h = false )
295
+ adtype:: ADTypes.AutoZygote , num_cons = 0 ; kwargs... )
300
296
x = cache. u0
301
297
p = cache. p
302
298
303
299
return OptimizationBase. instantiate_function (
304
- f, x, adtype, p, num_cons; g, h, hv, fg, fgh, cons_j, cons_vjp, cons_jvp, cons_h )
300
+ f, x, adtype, p, num_cons; kwargs ... )
305
301
end
306
302
307
303
function OptimizationBase. instantiate_function (
308
304
f:: OptimizationFunction{true} , x, adtype:: ADTypes.AutoSparse{<:AutoZygote} ,
309
305
p = SciMLBase. NullParameters (), num_cons = 0 ;
310
306
g = false , h = false , hv = false , fg = false , fgh = false ,
311
- cons_j = false , cons_vjp = false , cons_jvp = false , cons_h = false )
307
+ cons_j = false , cons_vjp = false , cons_jvp = false , cons_h = false ,
308
+ lag_h = false )
312
309
function _f (θ)
313
310
return f. f (θ, p)[1 ]
314
311
end
@@ -335,7 +332,7 @@ function OptimizationBase.instantiate_function(
335
332
grad = nothing
336
333
end
337
334
338
- if fg == true && f. fg ! == nothing
335
+ if fg == true && f. fg = == nothing
339
336
if g == false
340
337
extras_grad = prepare_gradient (_f, adtype. dense_ad, x)
341
338
end
@@ -361,7 +358,7 @@ function OptimizationBase.instantiate_function(
361
358
362
359
hess_sparsity = f. hess_prototype
363
360
hess_colors = f. hess_colorvec
364
- if f. hess === nothing
361
+ if h == true && f. hess === nothing
365
362
extras_hess = prepare_hessian (_f, soadtype, x) # placeholder logic, can be made much better
366
363
function hess (res, θ)
367
364
hessian! (_f, res, soadtype, θ, extras_hess)
@@ -384,7 +381,7 @@ function OptimizationBase.instantiate_function(
384
381
hess = nothing
385
382
end
386
383
387
- if fgh == true && f. fgh ! == nothing
384
+ if fgh == true && f. fgh = == nothing
388
385
function fgh! (G, H, θ)
389
386
(y, _, _) = value_derivative_and_second_derivative! (_f, G, H, θ, extras_hess)
390
387
return y
@@ -406,7 +403,7 @@ function OptimizationBase.instantiate_function(
406
403
fgh! = nothing
407
404
end
408
405
409
- if hv == true && f. hv ! == nothing
406
+ if hv == true && f. hv = == nothing
410
407
extras_hvp = prepare_hvp (_f, soadtype. dense_ad, x, zeros (eltype (x), size (x)))
411
408
function hv! (H, θ, v)
412
409
hvp! (_f, H, soadtype. dense_ad, θ, v, extras_hvp)
@@ -443,7 +440,7 @@ function OptimizationBase.instantiate_function(
443
440
θ = augvars[1 : length (x)]
444
441
σ = augvars[length (x) + 1 ]
445
442
λ = augvars[(length (x) + 2 ): end ]
446
- return σ * _f (θ) + dot (λ, cons (θ))
443
+ return σ * _f (θ) + dot (λ, cons_oop (θ))
447
444
end
448
445
end
449
446
@@ -466,7 +463,8 @@ function OptimizationBase.instantiate_function(
466
463
end
467
464
468
465
if f. cons_vjp === nothing && cons_vjp == true && cons != = nothing
469
- extras_pullback = prepare_pullback (cons_oop, adtype, x)
466
+ extras_pullback = prepare_pullback (
467
+ cons_oop, adtype. dense_ad, x, ones (eltype (x), num_cons))
470
468
function cons_vjp! (J, θ, v)
471
469
pullback! (cons_oop, J, adtype. dense_ad, θ, v, extras_pullback)
472
470
end
@@ -477,7 +475,8 @@ function OptimizationBase.instantiate_function(
477
475
end
478
476
479
477
if f. cons_jvp === nothing && cons_jvp == true && cons != = nothing
480
- extras_pushforward = prepare_pushforward (cons_oop, adtype, x)
478
+ extras_pushforward = prepare_pushforward (
479
+ cons_oop, adtype. dense_ad, x, ones (eltype (x), length (x)))
481
480
function cons_jvp! (J, θ, v)
482
481
pushforward! (cons_oop, J, adtype. dense_ad, θ, v, extras_pushforward)
483
482
end
@@ -510,10 +509,11 @@ function OptimizationBase.instantiate_function(
510
509
end
511
510
512
511
lag_hess_prototype = f. lag_hess_prototype
513
- if cons != = nothing && cons_h == true && f. lag_h === nothing
512
+ lag_hess_colors = f. lag_hess_colorvec
513
+ if cons != = nothing && f. lag_h === nothing && lag_h == true
514
514
lag_extras = prepare_hessian (
515
515
lagrangian, soadtype, vcat (x, [one (eltype (x))], ones (eltype (x), num_cons)))
516
- lag_hess_prototype = lag_extras. coloring_result. S[1 : length (θ ), 1 : length (θ )]
516
+ lag_hess_prototype = lag_extras. coloring_result. S[1 : length (x ), 1 : length (x )]
517
517
lag_hess_colors = lag_extras. coloring_result. color
518
518
519
519
function lag_h! (H:: AbstractMatrix , θ, σ, λ)
@@ -587,14 +587,11 @@ end
587
587
588
588
function OptimizationBase. instantiate_function (
589
589
f:: OptimizationFunction{true} , cache:: OptimizationBase.ReInitCache ,
590
- adtype:: ADTypes.AutoSparse{<:AutoZygote} , num_cons = 0 ;
591
- g = false , h = false , hv = false , fg = false , fgh = false ,
592
- cons_j = false , cons_vjp = false , cons_jvp = false , cons_h = false )
590
+ adtype:: ADTypes.AutoSparse{<:AutoZygote} , num_cons = 0 ; kwargs... )
593
591
x = cache. u0
594
592
p = cache. p
595
593
596
- return OptimizationBase. instantiate_function (
597
- f, x, adtype, p, num_cons; g, h, hv, fg, fgh, cons_j, cons_vjp, cons_jvp, cons_h)
594
+ return OptimizationBase. instantiate_function (f, x, adtype, p, num_cons; kwargs... )
598
595
end
599
596
600
597
end
0 commit comments