Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ docs/site/
# environment.
Manifest.toml

.vscode
.vscode
script/
77 changes: 70 additions & 7 deletions src/deim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,45 @@ function deim_interpolation_indices(basis::AbstractMatrix)::Vector{Int}
return indices
end

"""
$(TYPEDSIGNATURES)

Compute the QDEIM interpolation indices for the given projection basis.
"""
function qdeim_interpolation_indices(basis::AbstractMatrix)::Vector{Int}
dim = size(basis, 2)
return qr(basis', ColumnNorm()).p[1:dim]
end

"""
$(TYPEDSIGNATURES)

Compute the ODEIM interpolation indices for the given projection basis.
"""
function odeim_interpolation_indices(basis::AbstractMatrix, m::Int)::Vector{Int}
dim = size(basis, 2)
@assert m >= dim && m <= size(basis,1) "Invalid sampling dimension"

# Compute the first dim points with QDEIM
p = qdeim_interpolation_indices(basis)

# select points n+1, ..., m
for _ in (length(p) + 1):m
_, S, W = svd(basis[p, :])
gap = S[end - 1]^2 - S[end]^2 # eigengap
proj_basis = W' * basis'
r = gap .+ sum(proj_basis.^2, dims=1)
r -= sqrt.((gap .+ sum(proj_basis.^2, dims=1)).^2 - 4 * gap * (proj_basis[end, :].^2)')
indices = sortperm(vec(r), rev=true)
e = 1
while any(indices[e] .== p)
e += 1
end
push!(p, indices[e])
end
return p
end

"""
$(SIGNATURES)

Expand All @@ -51,6 +90,10 @@ where ``P=[\\mathbf e_{\\rho_1},\\dots,\\mathbf e_{\\rho_m}]\\in\\mathbb R^{n\\t
algorithm, and ``\\mathbf e_{\\rho_i}=[0,\\ldots,0,1,0,\\ldots,0]^T\\in\\mathbb R^n`` is
the ``\\rho_i``-th column of the identity matrix ``I_n\\in\\mathbb R^{n\\times n}``.

Besides the standard DEIM algorithm for interpolation, this method also supports the QDEIM
and the ODEIM algorithms. The ODEIM algorithm requires an additional parameter `odeim_dim`
to specify the number of the oversampled interpolation points.

# Arguments
- `full_vars::AbstractVector`: the dependent variables ``\\underset{n\\times 1}{\\mathbf y}`` in FOM.
- `linear_coeffs::AbstractMatrix`: the coefficient matrix ``\\underset{n\\times n}A`` of linear terms in FOM.
Expand All @@ -59,15 +102,22 @@ the ``\\rho_i``-th column of the identity matrix ``I_n\\in\\mathbb R^{n\\times n
- `reduced_vars::AbstractVector`: the dependent variables ``\\underset{k\\times 1}{\\hat{\\mathbf y}}`` in the reduced-order model.
- `linear_projection_matrix::AbstractMatrix`: the projection matrix ``\\underset{n\\times k}V`` for the dependent variables ``\\mathbf y``.
- `nonlinear_projection_matrix::AbstractMatrix`: the projection matrix ``\\underset{n\\times m}U`` for the nonlinear functions ``\\mathbf F``.
- `interpolation_algo::Symbol`: the interpolation algorithm, which can be `:deim`, `:qdeim`, or `:odeim`.

# Return
- `reduced_rhss`: the right-hand side of ROM.
- `linear_projection_eqs`: the linear projection mapping ``\\mathbf y=V\\hat{\\mathbf y}``.

# References
- [DEIM](https://epubs.siam.org/doi/abs/10.1137/110822724): Chaturantabut and Sorensen, 2012.
- [QDEIM](http://epubs.siam.org/doi/10.1137/15M1019271): Drmac and Gugercin, 2016.
- [ODEIM](https://epubs.siam.org/doi/10.1137/19M1307391): Peherstorfer, Drmac, and Gugercin, 2020.
"""
function deim(full_vars::AbstractVector, linear_coeffs::AbstractMatrix,
constant_part::AbstractVector, nonlinear_part::AbstractVector,
reduced_vars::AbstractVector, linear_projection_matrix::AbstractMatrix,
nonlinear_projection_matrix::AbstractMatrix; kwargs...)
nonlinear_projection_matrix::AbstractMatrix,
interpolation_algo::Symbol, odeim_dim::Integer; kwargs...)
# rename variables for convenience
y = full_vars
A = linear_coeffs
Expand All @@ -81,7 +131,13 @@ function deim(full_vars::AbstractVector, linear_coeffs::AbstractMatrix,
linear_projection_eqs = Symbolics.scalarize(y .~ V * ŷ)
linear_projection_dict = Dict(eq.lhs => eq.rhs for eq in linear_projection_eqs)

indices = deim_interpolation_indices(U) # DEIM interpolation indices
if interpolation_algo == :deim
indices = deim_interpolation_indices(U) # DEIM interpolation indices
elseif interpolation_algo == :qdeim
indices = qdeim_interpolation_indices(U) # QDEIM interpolation indices
elseif interpolation_algo == :odeim
indices = odeim_interpolation_indices(U, odeim_dim) # ODEIM interpolation indices
end
# the DEIM projector (not DEIM basis) satisfies
# F(original_vars) ≈ projector * F(pod_basis * reduced_vars)[indices]
projector = ((@view U[indices, :])' \ (U' * V))'
Expand All @@ -91,15 +147,17 @@ function deim(full_vars::AbstractVector, linear_coeffs::AbstractMatrix,
 = V' * A * V
ĝ = V' * g
reduced_rhss = Â * ŷ + ĝ + F̂
reduced_rhss, linear_projection_eqs
return reduced_rhss, linear_projection_eqs
end

"""
$(FUNCTIONNAME)(
sys::ModelingToolkit.ODESystem,
snapshot::AbstractMatrix,
pod_dim::Integer;
deim_dim::Integer = pod_dim,
name::Symbol = Symbol(nameof(sys), :_deim),
interpolation_algo::Symbol = :deim,
kwargs...
) -> ModelingToolkit.ODESystem

Expand All @@ -116,11 +174,16 @@ The LHS of equations in `sys` are all assumed to be 1st order derivatives. Use

The POD basis used for DEIM interpolation is obtained from the snapshot matrix of the
nonlinear terms, which is computed by executing the runtime-generated function for
nonlinear expressions.
nonlinear expressions.

Additional to the DEIM algorithm, this function also supports the QDEIM and ODEIM. For ODEIM,
the `odeim_dim` parameter specifies the number of oversampled interpolation points.
"""
function deim(sys::ODESystem, snapshot::AbstractMatrix, pod_dim::Integer;
deim_dim::Integer = pod_dim, name::Symbol = Symbol(nameof(sys), :_deim),
kwargs...)::ODESystem
deim_dim::Integer = pod_dim, odeim_dim::Integer = 2*pod_dim,
name::Symbol = Symbol(nameof(sys), :_deim),
interpolation_algo::Symbol = :deim, kwargs...)::ODESystem
@assert interpolation_algo ∈ (:deim, :qdeim, :odeim) "Invalid interpolation algorithm"
sys = deepcopy(sys)
@set! sys.name = name

Expand Down Expand Up @@ -158,7 +221,7 @@ function deim(sys::ODESystem, snapshot::AbstractMatrix, pod_dim::Integer;
reduce!(deim_reducer, TSVD())
U = deim_reducer.rbasis # DEIM projection basis

reduced_rhss, linear_projection_eqs = deim(dvs, A, g, F, ŷ, V, U; kwargs...)
reduced_rhss, linear_projection_eqs = deim(dvs, A, g, F, ŷ, V, U, interpolation_algo, odeim_dim; kwargs...)

reduced_deqs = D.(ŷ) ~ reduced_rhss
@set! sys.eqs = [Symbolics.scalarize(reduced_deqs); eqs]
Expand Down
30 changes: 26 additions & 4 deletions test/deim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,40 @@ sol = solve(ode_prob, Tsit5(), saveat = 1.0)

snapshot_simpsys = Array(sol.original_sol)
pod_dim = 3

# test DEIM
deim_sys = @test_nowarn deim(simp_sys, snapshot_simpsys, pod_dim)
# test QDEIM
qdeim_sys = @test_nowarn deim(simp_sys, snapshot_simpsys, pod_dim; interpolation_algo=:qdeim)
# test ODEIM
odeim_sys = @test_nowarn deim(simp_sys, snapshot_simpsys, pod_dim; interpolation_algo=:odeim)

# check the number of dependent variables in the new system
# DEIM: check the number of dependent variables in the new system
@test length(ModelingToolkit.get_states(deim_sys)) == pod_dim

deim_prob = ODEProblem(deim_sys, nothing, tspan)

deim_sol = solve(deim_prob, Tsit5(), saveat = 1.0)

nₓ = length(sol[x])
nₜ = length(sol[t])

# test solution retrieva
# test solution retrieval
@test size(deim_sol[v(x, t)]) == (nₓ, nₜ)
@test size(deim_sol[w(x, t)]) == (nₓ, nₜ)

# QDEIM: check the number of dependent variables in the new system
@test length(ModelingToolkit.get_states(qdeim_sys)) == pod_dim
deim_prob = ODEProblem(qdeim_sys, nothing, tspan)
deim_sol = solve(deim_prob, Tsit5(), saveat = 1.0)

# test solution retrieval
@test size(deim_sol[v(x, t)]) == (nₓ, nₜ)
@test size(deim_sol[w(x, t)]) == (nₓ, nₜ)

# ODEIM: check the number of dependent variables in the new system
@test length(ModelingToolkit.get_states(odeim_sys)) == pod_dim
deim_prob = ODEProblem(odeim_sys, nothing, tspan)
deim_sol = solve(deim_prob, Tsit5(), saveat = 1.0)

# test solution retrieval
@test size(deim_sol[v(x, t)]) == (nₓ, nₜ)
@test size(deim_sol[w(x, t)]) == (nₓ, nₜ)
Loading