@@ -2725,11 +2725,11 @@ function get_moments(𝓂::ℳ;
27252725 @info " Most of the time is spent calculating derivatives wrt parameters. If they are not needed, add `derivatives = false` as an argument to the function call." maxlog = DEFAULT_MAXLOG
27262726 end
27272727
2728- if (! variance && ! standard_deviation && ! non_stochastic_steady_state && ! mean)
2728+ if (! variance && ! standard_deviation && ! non_stochastic_steady_state && ! mean && ! covariance )
27292729 derivatives = false
27302730 end
27312731
2732- if parameter_derivatives != :all && (variance || standard_deviation || non_stochastic_steady_state || mean)
2732+ if parameter_derivatives != :all && (variance || standard_deviation || non_stochastic_steady_state || mean || covariance )
27332733 derivatives = true
27342734 end
27352735
@@ -2878,14 +2878,30 @@ function get_moments(𝓂::ℳ;
28782878
28792879
28802880 if covariance
2881+ axis3 = vcat (:Covariance , 𝓂. parameters[param_idx])
2882+
2883+ if any (x -> contains (string (x), " ◖" ), axis3)
2884+ axis3_decomposed = decompose_name .(axis3)
2885+ axis3 = [length (a) > 1 ? string (a[1 ]) * " {" * join (a[2 ]," }{" ) * " }" * (a[end ] isa Symbol ? string (a[end ]) : " " ) : string (a[1 ]) for a in axis3_decomposed]
2886+ end
2887+
28812888 if algorithm == :pruned_second_order
28822889 covar_dcmp, Σᶻ₂, state_μ, Δμˢ₂, autocorr_tmp, ŝ_to_ŝ₂, ŝ_to_y₂, Σʸ₁, Σᶻ₁, SS_and_pars, 𝐒₁, ∇₁, 𝐒₂, ∇₂, solved = calculate_second_order_moments_with_covariance (𝓂. parameter_values, 𝓂, opts = opts)
2890+
2891+ # Compute covariance derivatives
2892+ dcovariance = 𝒟. jacobian (x -> vec (calculate_second_order_moments_with_covariance (x, 𝓂, opts = opts)[1 ]), backend, 𝓂. parameter_values)[:,param_idx]
28832893 elseif algorithm == :pruned_third_order
28842894 covar_dcmp, state_μ, _, solved = calculate_third_order_moments (𝓂. parameter_values, :full_covar , 𝓂, opts = opts)
2895+
2896+ # Compute covariance derivatives
2897+ dcovariance = 𝒟. jacobian (x -> vec (calculate_third_order_moments (x, :full_covar , 𝓂, opts = opts)[1 ]), backend, 𝓂. parameter_values)[:,param_idx]
28852898 else
28862899 covar_dcmp, ___, __, _, solved = calculate_covariance (𝓂. parameter_values, 𝓂, opts = opts)
28872900
28882901 @assert solved " Could not find covariance matrix."
2902+
2903+ # Compute covariance derivatives
2904+ dcovariance = 𝒟. jacobian (x -> vec (calculate_covariance (x, 𝓂, opts = opts)[1 ]), backend, 𝓂. parameter_values)[:,param_idx]
28892905 end
28902906 end
28912907
@@ -3045,8 +3061,38 @@ function get_moments(𝓂::ℳ;
30453061 axis1 = [length (a) > 1 ? string (a[1 ]) * " {" * join (a[2 ]," }{" ) * " }" * (a[end ] isa Symbol ? string (a[end ]) : " " ) : string (a[1 ]) for a in axis1_decomposed]
30463062 end
30473063
3048- # push!(ret,KeyedArray(covar_dcmp[var_idx, var_idx]; Variables = axis1, 𝑉𝑎𝑟𝑖𝑎𝑏𝑙𝑒𝑠 = axis1))
3049- ret[:covariance ] = KeyedArray (covar_dcmp[var_idx, var_idx]; Variables = axis1, 𝑉𝑎𝑟𝑖𝑎𝑏𝑙𝑒𝑠 = axis1)
3064+ if derivatives
3065+ # Reshape the flattened derivatives back to n x n x p structure
3066+ n_vars = length (var_idx)
3067+ n_params = length (param_idx)
3068+
3069+ # Create array to hold covariance and derivatives: n x n x (1 + p)
3070+ covar_with_derivs = zeros (n_vars, n_vars, 1 + n_params)
3071+
3072+ # First slice is the covariance matrix
3073+ covar_with_derivs[:, :, 1 ] = covar_dcmp[var_idx, var_idx]
3074+
3075+ # Subsequent slices are derivatives wrt each parameter
3076+ for i in 1 : n_params
3077+ # dcovariance[:,i] is vectorized, need to reshape to n x n
3078+ covar_with_derivs[:, :, i+ 1 ] = reshape (dcovariance[:, i], n_vars, n_vars)
3079+ end
3080+
3081+ # Create axis names
3082+ if ! @isdefined axis3
3083+ axis3 = vcat (:Covariance , 𝓂. parameters[param_idx])
3084+
3085+ if any (x -> contains (string (x), " ◖" ), axis3)
3086+ axis3_decomposed = decompose_name .(axis3)
3087+ axis3 = [length (a) > 1 ? string (a[1 ]) * " {" * join (a[2 ]," }{" ) * " }" * (a[end ] isa Symbol ? string (a[end ]) : " " ) : string (a[1 ]) for a in axis3_decomposed]
3088+ end
3089+ end
3090+
3091+ ret[:covariance ] = KeyedArray (covar_with_derivs; Variables = axis1, 𝑉𝑎𝑟𝑖𝑎𝑏𝑙𝑒𝑠 = axis1, Covariance_and_∂covariance∂parameter = axis3)
3092+ else
3093+ # push!(ret,KeyedArray(covar_dcmp[var_idx, var_idx]; Variables = axis1, 𝑉𝑎𝑟𝑖𝑎𝑏𝑙𝑒𝑠 = axis1))
3094+ ret[:covariance ] = KeyedArray (covar_dcmp[var_idx, var_idx]; Variables = axis1, 𝑉𝑎𝑟𝑖𝑎𝑏𝑙𝑒𝑠 = axis1)
3095+ end
30503096 end
30513097
30523098 return ret
0 commit comments