
Commit fa42939

Add default function for DI + option to use in data_domain gradients (#163)
Parent: 15492da

3 files changed (+55, -30 lines)

src/Jutul.jl

Lines changed: 2 additions & 2 deletions

@@ -40,8 +40,8 @@ module Jutul
 # AD
 import ForwardDiff

-import DifferentiationInterface: AutoSparse, prepare_jacobian, jacobian
-import SparseConnectivityTracer: TracerLocalSparsityDetector
+import DifferentiationInterface: AutoSparse, prepare_jacobian, jacobian, AutoForwardDiff
+import SparseConnectivityTracer: TracerLocalSparsityDetector, GradientTracer, IndexSetGradientPattern
 import SparseMatrixColorings: GreedyColoringAlgorithm

 # Timing
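
Note (not part of the commit): the newly imported names are the pieces needed to assemble a sparse ForwardDiff backend for DifferentiationInterface. A minimal standalone sketch, assuming DifferentiationInterface, SparseConnectivityTracer, SparseMatrixColorings and ForwardDiff are installed; the toy function f is illustrative only:

import DifferentiationInterface: AutoSparse, AutoForwardDiff, prepare_jacobian, jacobian
import SparseConnectivityTracer: TracerLocalSparsityDetector
import SparseMatrixColorings: GreedyColoringAlgorithm

# Sparse backend: tracer-based sparsity detection + greedy coloring on top of ForwardDiff
backend = AutoSparse(
    AutoForwardDiff();
    sparsity_detector = TracerLocalSparsityDetector(),
    coloring_algorithm = GreedyColoringAlgorithm(),
)

f(x) = [x[1]^2, x[2] + x[3], x[1]*x[3]]   # toy vector-valued function
x = [1.0, 2.0, 3.0]
prep = prepare_jacobian(f, backend, x)    # detect the sparsity pattern and coloring once
J = jacobian(f, prep, backend, x)         # reuse the preparation for the sparse Jacobian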

src/ad/AdjointsDI/adjoints.jl

Lines changed: 2 additions & 13 deletions

@@ -93,9 +93,9 @@ end

 function setup_adjoint_storage_generic(X, F, packed_steps::AdjointPackedResult, G;
         state0 = missing,
-        backend = missing,
         do_prep = true,
         di_sparse = true,
+        backend = Jutul.default_di_backend(sparse = di_sparse),
         info_level = 0,
         single_step_sparsity = true,
         use_sparsity = true
@@ -285,18 +285,7 @@ end

 function setup_jacobian_evaluation!(storage, X, F, G, packed_steps, case0, backend, do_prep, single_step_sparsity, di_sparse)
     if ismissing(backend)
-        if di_sparse
-            gt = SparseConnectivityTracer.GradientTracer{SparseConnectivityTracer.IndexSetGradientPattern{Int, Set{Int}}}
-            sparsity_detector = TracerLocalSparsityDetector(gradient_tracer_type=gt)
-            # sparsity_detector = TracerLocalSparsityDetector()
-            backend = AutoSparse(
-                AutoForwardDiff();
-                sparsity_detector = sparsity_detector,
-                coloring_algorithm = GreedyColoringAlgorithm(),
-            )
-        else
-            backend = AutoForwardDiff()
-        end
+        backend = Jutul.default_di_backend(sparse = di_sparse)
     end

     H = AdjointObjectiveHelper(F, G, packed_steps)
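
Note (not part of the commit): moving `backend` below `di_sparse` in the keyword list above is what lets its default depend on `di_sparse`, since a Julia keyword default may reference keywords declared earlier in the same signature. A toy sketch of that pattern, with illustrative names that are not Jutul API:

default_backend_name(; sparse = true) = sparse ? :sparse_forwarddiff : :dense_forwarddiff

function toy_setup(; di_sparse = true, backend = default_backend_name(sparse = di_sparse))
    return backend
end

toy_setup()                     # :sparse_forwarddiff
toy_setup(di_sparse = false)    # :dense_forwarddiff
toy_setup(backend = :custom)    # an explicitly passed backend still takes precedence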

src/variables/vectorization.jl

Lines changed: 51 additions & 15 deletions

@@ -221,34 +221,70 @@ Compute the (sparse) Jacobian of parameters with respect to data_domain values
 (i.e. floating point values). Optionally, `config` can be passed to allow
 `vectorize_variables` to only include a subset of the parameters.
 """
-function parameters_jacobian_wrt_data_domain(model; copy = true, config = nothing)
-    if copy
-        model = deepcopy(model)
-    end
+function parameters_jacobian_wrt_data_domain(model;
+        copy = true,
+        config = nothing,
+        use_di = true,
+        backend = default_di_backend(),
+        prep = missing
+    )
     data_domain = model.data_domain
     x = vectorize_data_domain(data_domain)
-    x_ad = ST.create_advec(x)
-    devectorize_data_domain!(data_domain, x_ad)
-    prm = setup_parameters(model, perform_copy = false)
-    prm_flat = vectorize_variables(model, prm, :parameters, T = eltype(x_ad), config = config)
-    n_parameters = length(prm_flat)
-    n_data_domain_values = length(x)
-    # This is Jacobian of parameters with respect to data_domain
-    J = ST.jacobian(prm_flat, n_data_domain_values)
-    @assert size(J) == (n_parameters, n_data_domain_values)
+    if use_di
+        if copy
+            model = deepcopy(model)
+        end
+        function F(X)
+            model_tmp = deepcopy(model)
+            dd = model_tmp.data_domain
+            devectorize_data_domain!(dd, X)
+            prm = setup_parameters(model_tmp, perform_copy = false)
+            return vectorize_variables(model_tmp, prm, :parameters, T = eltype(X), config = config)
+        end
+        if ismissing(prep)
+            J = jacobian(F, backend, x)
+        else
+            J = jacobian(F, prep, backend, x)
+        end
+    else
+        x_ad = ST.create_advec(x)
+        devectorize_data_domain!(data_domain, x_ad)
+        prm = setup_parameters(model, perform_copy = false)
+        prm_flat = vectorize_variables(model, prm, :parameters, T = eltype(x_ad), config = config)
+        n_parameters = length(prm_flat)
+        n_data_domain_values = length(x)
+        # This is Jacobian of parameters with respect to data_domain
+        J = ST.jacobian(prm_flat, n_data_domain_values)
+        @assert size(J) == (n_parameters, n_data_domain_values)
+    end
     return J
 end

+function default_di_backend(; sparse = true)
+    if sparse
+        gt = GradientTracer{IndexSetGradientPattern{Int, Set{Int}}}
+        sparsity_detector = TracerLocalSparsityDetector(gradient_tracer_type=gt)
+        backend = AutoSparse(
+            AutoForwardDiff();
+            sparsity_detector = sparsity_detector,
+            coloring_algorithm = GreedyColoringAlgorithm(),
+        )
+    else
+        backend = AutoForwardDiff()
+    end
+    return backend
+end
+
 """
     data_domain_to_parameters_gradient(model, parameter_gradient; dp_dd = missing, config = nothing)

 Make a data_domain copy that contains the gradient of some objective with
 respect to the fields in the data_domain, assuming that the parameters were
 initialized directly from the data_domain via (`setup_parameters`)[@ref].
 """
-function data_domain_to_parameters_gradient(model, parameter_gradient; dp_dd = missing, config = nothing)
+function data_domain_to_parameters_gradient(model, parameter_gradient; dp_dd = missing, config = nothing, kwarg...)
     if ismissing(dp_dd)
-        dp_dd = parameters_jacobian_wrt_data_domain(model, copy = true, config = config)
+        dp_dd = parameters_jacobian_wrt_data_domain(model; copy = true, config = config, kwarg...)
     end
     do_dp = vectorize_variables(model, parameter_gradient, :parameters, config = config)
     # do/dd = do/dp * dp/dd
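
Note (not part of the commit): the truncated hunk ends at the chain-rule comment that `data_domain_to_parameters_gradient` relies on, do/dd = do/dp * dp/dd. A minimal numeric check of that identity with plain ForwardDiff, using toy stand-ins P (data_domain values to parameters) and o (scalar objective), neither of which is Jutul API:

import ForwardDiff

P(d) = [d[1] + 2d[2], d[2]^2, d[1]*d[3]]       # toy "setup_parameters" map
o(p) = sum(abs2, p)                            # toy scalar objective of the parameters

d = [1.0, 2.0, 3.0]
dp_dd = ForwardDiff.jacobian(P, d)             # dp/dd
do_dp = ForwardDiff.gradient(o, P(d))          # do/dp
do_dd = dp_dd' * do_dp                         # do/dd (transposed product in column-gradient form)
do_dd ≈ ForwardDiff.gradient(d -> o(P(d)), d)  # true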
