Skip to content

Commit

Permalink
Add ergodic variance to 2nd order solution (#141)
Browse files Browse the repository at this point in the history
* Added x_ergodic to the 2nd order
  • Loading branch information
jlperla authored Jun 22, 2022
1 parent fb196e7 commit d74c9c7
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 12 deletions.
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ jobs:
fail-fast: false
matrix:
version:
- "1.6"
- "1.7"
os:
- ubuntu-latest
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiableStateSpaceModels"
uuid = "beacd9db-9e5e-4956-9b09-459a4b2028df"
authors = ["Jesse Perla <jesseperla@gmail.com> and contributors"]
version = "0.4.18"
version = "0.4.19"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -25,7 +25,7 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[compat]
ChainRulesCore = "1"
DifferenceEquations = "0.4.16"
DifferenceEquations = "0.4.17"
DocStringExtensions = "0.8"
MultivariatePolynomials = "0.4.4"
GeneralizedSylvesterSolver = "0.1"
Expand All @@ -39,4 +39,4 @@ RecursiveFactorization = "0.2"
StructArrays = "0.6"
SymbolicUtils = "0.19.7"
Symbolics = "4"
julia = "1.6, 1.7"
julia = "1.7"
8 changes: 8 additions & 0 deletions src/generate_perturbation_derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,14 @@ function ChainRulesCore.rrule(::typeof(generate_perturbation), m::PerturbationMo
Δp[i] += dot(c.g_x_p[i], Δsol.g_x)
end
end
if (~isnothing(Δsol.x_ergodic_var))
if ((Δsol.x_ergodic_var != NoTangent()) &
(Δsol.x_ergodic_var != ZeroTangent()))
for i in 1:n_p_d
Δp[i] += dot(c.V_p[i], Δsol.x_ergodic_var)
end
end
end
if (~iszero(Δsol.C_1))
for i in 1:n_p_d
Δp[i] += dot(c.C_1_p[i], Δsol.C_1)
Expand Down
7 changes: 6 additions & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ struct SecondOrderPerturbationSolution{T1<:AbstractVector,T2<:AbstractVector,
T10,T11<:AbstractArray,T12,
T13<:AbstractVector,T14<:AbstractMatrix,
T15<:AbstractVector,
T16<:AbstractArray} <:
T16<:AbstractArray,T17<:AbstractMatrix} <:
AbstractSecondOrderPerturbationSolution
retcode::Symbol
x_symbols::Vector{Symbol}
Expand All @@ -481,6 +481,7 @@ struct SecondOrderPerturbationSolution{T1<:AbstractVector,T2<:AbstractVector,
D::T6
Q::T7 # can be nothing
η::T8
x_ergodic_var::T17
Γ::T9

g_xx::T10
Expand Down Expand Up @@ -514,6 +515,10 @@ function SecondOrderPerturbationSolution(retcode, m::PerturbationModel, c::Solve
make_covariance_matrix(c.Ω),
c.Q,
c.η,
(settings.calculate_ergodic_distribution == true) ?
c.V :
diagm(settings.singular_covariance_value *
ones(m.n_x)),
c.Γ,
c.g_xx,
c.g_σσ,
Expand Down
13 changes: 6 additions & 7 deletions test/second_order_gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,20 @@ function test_second_order(p_d, p_f, m)
sol = generate_perturbation(m, p_d, p_f, Val(2))
return sum(sol.y) + sum(sol.x) + sum(sol.A_0) + +sum(sol.A_1) + sum(sol.A_2) +
sum(sol.B) + sum(sol.C_0) + sum(sol.C_1) + sum(sol.C_2) + sum(sol.D) +
sum(sol.g_xx) + sum(sol.g_σσ) + sum(sol.g_x)
sum(sol.g_xx) + sum(sol.g_σσ) + sum(sol.g_x) + sum(sol.x_ergodic_var)
end

# Some sort of inference issues. Trouble putting in function and need the `const` for now
# see #117
#@testset "grad_tests" begin
const m_grad_2 = @include_example_module(Examples.rbc_observables) # const fixes current bug. Can't move inside
# @testset "grad_tests" begin
m_grad_2 = @include_example_module(Examples.rbc_observables) # const fixes current bug. Can't move inside
p_f == 0.2, δ = 0.02, σ = 0.01, Ω_1 = 0.01)
p_d == 0.5, β = 0.95)
test_second_order(p_d, p_f, m_grad_2)
gradient((args...) -> test_second_order(args..., p_f, m_grad), p_d)
@test test_second_order(p_d, p_f, m_grad_2) 88.3145253236005
gradient((args...) -> test_second_order(args..., p_f, m_grad_2), p_d)

test_rrule(Zygote.ZygoteRuleConfig(),
(args...) -> test_second_order(args..., p_f, m_grad_2), p_d;
rrule_f = rrule_via_ad,
check_inferred = false)

#end
# end
1 change: 1 addition & 0 deletions test/second_order_perturbation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ end
@test c.h_x sol.A_1
@test 0.5 * c.h_xx sol.A_2
@test c.Ω sqrt.(sol.D)
@test c.V sol.x_ergodic_var
end

@testset "Evaluate 2nd Order Derivatives into cache" begin
Expand Down

2 comments on commit d74c9c7

@jlperla
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/62891

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.19 -m "<description of version>" d74c9c71e8e891d6cb27923a8659cf93992cfd84
git push origin v0.4.19

Please sign in to comment.