Skip to content

Commit 8022d9f

Browse files
authored
Merge pull request #156 from thorek1/copilot/fix-20a9a78f-efdd-429a-8977-682bd45f4534
2 parents b1b5f04 + d88cc6e commit 8022d9f

File tree

1 file changed

+50
-4
lines changed

1 file changed

+50
-4
lines changed

src/get_functions.jl

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2725,11 +2725,11 @@ function get_moments(𝓂::ℳ;
27252725
@info "Most of the time is spent calculating derivatives wrt parameters. If they are not needed, add `derivatives = false` as an argument to the function call." maxlog = DEFAULT_MAXLOG
27262726
end
27272727

2728-
if (!variance && !standard_deviation && !non_stochastic_steady_state && !mean)
2728+
if (!variance && !standard_deviation && !non_stochastic_steady_state && !mean && !covariance)
27292729
derivatives = false
27302730
end
27312731

2732-
if parameter_derivatives != :all && (variance || standard_deviation || non_stochastic_steady_state || mean)
2732+
if parameter_derivatives != :all && (variance || standard_deviation || non_stochastic_steady_state || mean || covariance)
27332733
derivatives = true
27342734
end
27352735

@@ -2878,14 +2878,30 @@ function get_moments(𝓂::ℳ;
28782878

28792879

28802880
if covariance
2881+
axis3 = vcat(:Covariance, 𝓂.parameters[param_idx])
2882+
2883+
if any(x -> contains(string(x), ""), axis3)
2884+
axis3_decomposed = decompose_name.(axis3)
2885+
axis3 = [length(a) > 1 ? string(a[1]) * "{" * join(a[2],"}{") * "}" * (a[end] isa Symbol ? string(a[end]) : "") : string(a[1]) for a in axis3_decomposed]
2886+
end
2887+
28812888
if algorithm == :pruned_second_order
28822889
covar_dcmp, Σᶻ₂, state_μ, Δμˢ₂, autocorr_tmp, ŝ_to_ŝ₂, ŝ_to_y₂, Σʸ₁, Σᶻ₁, SS_and_pars, 𝐒₁, ∇₁, 𝐒₂, ∇₂, solved = calculate_second_order_moments_with_covariance(𝓂.parameter_values, 𝓂, opts = opts)
2890+
2891+
# Compute covariance derivatives
2892+
dcovariance = 𝒟.jacobian(x -> vec(calculate_second_order_moments_with_covariance(x, 𝓂, opts = opts)[1]), backend, 𝓂.parameter_values)[:,param_idx]
28832893
elseif algorithm == :pruned_third_order
28842894
covar_dcmp, state_μ, _, solved = calculate_third_order_moments(𝓂.parameter_values, :full_covar, 𝓂, opts = opts)
2895+
2896+
# Compute covariance derivatives
2897+
dcovariance = 𝒟.jacobian(x -> vec(calculate_third_order_moments(x, :full_covar, 𝓂, opts = opts)[1]), backend, 𝓂.parameter_values)[:,param_idx]
28852898
else
28862899
covar_dcmp, ___, __, _, solved = calculate_covariance(𝓂.parameter_values, 𝓂, opts = opts)
28872900

28882901
@assert solved "Could not find covariance matrix."
2902+
2903+
# Compute covariance derivatives
2904+
dcovariance = 𝒟.jacobian(x -> vec(calculate_covariance(x, 𝓂, opts = opts)[1]), backend, 𝓂.parameter_values)[:,param_idx]
28892905
end
28902906
end
28912907

@@ -3045,8 +3061,38 @@ function get_moments(𝓂::ℳ;
30453061
axis1 = [length(a) > 1 ? string(a[1]) * "{" * join(a[2],"}{") * "}" * (a[end] isa Symbol ? string(a[end]) : "") : string(a[1]) for a in axis1_decomposed]
30463062
end
30473063

3048-
# push!(ret,KeyedArray(covar_dcmp[var_idx, var_idx]; Variables = axis1, 𝑉𝑎𝑟𝑖𝑎𝑏𝑙𝑒𝑠 = axis1))
3049-
ret[:covariance] = KeyedArray(covar_dcmp[var_idx, var_idx]; Variables = axis1, 𝑉𝑎𝑟𝑖𝑎𝑏𝑙𝑒𝑠 = axis1)
3064+
if derivatives
3065+
# Reshape the flattened derivatives back to n x n x p structure
3066+
n_vars = length(var_idx)
3067+
n_params = length(param_idx)
3068+
3069+
# Create array to hold covariance and derivatives: n x n x (1 + p)
3070+
covar_with_derivs = zeros(n_vars, n_vars, 1 + n_params)
3071+
3072+
# First slice is the covariance matrix
3073+
covar_with_derivs[:, :, 1] = covar_dcmp[var_idx, var_idx]
3074+
3075+
# Subsequent slices are derivatives wrt each parameter
3076+
for i in 1:n_params
3077+
# dcovariance[:,i] is vectorized, need to reshape to n x n
3078+
covar_with_derivs[:, :, i+1] = reshape(dcovariance[:, i], n_vars, n_vars)
3079+
end
3080+
3081+
# Create axis names
3082+
if !@isdefined axis3
3083+
axis3 = vcat(:Covariance, 𝓂.parameters[param_idx])
3084+
3085+
if any(x -> contains(string(x), ""), axis3)
3086+
axis3_decomposed = decompose_name.(axis3)
3087+
axis3 = [length(a) > 1 ? string(a[1]) * "{" * join(a[2],"}{") * "}" * (a[end] isa Symbol ? string(a[end]) : "") : string(a[1]) for a in axis3_decomposed]
3088+
end
3089+
end
3090+
3091+
ret[:covariance] = KeyedArray(covar_with_derivs; Variables = axis1, 𝑉𝑎𝑟𝑖𝑎𝑏𝑙𝑒𝑠 = axis1, Covariance_and_∂covariance∂parameter = axis3)
3092+
else
3093+
# push!(ret,KeyedArray(covar_dcmp[var_idx, var_idx]; Variables = axis1, 𝑉𝑎𝑟𝑖𝑎𝑏𝑙𝑒𝑠 = axis1))
3094+
ret[:covariance] = KeyedArray(covar_dcmp[var_idx, var_idx]; Variables = axis1, 𝑉𝑎𝑟𝑖𝑎𝑏𝑙𝑒𝑠 = axis1)
3095+
end
30503096
end
30513097

30523098
return ret

0 commit comments

Comments
 (0)