Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SchurComplementW for ClimaTimesteppers #1364

Merged
merged 1 commit into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 105 additions & 82 deletions examples/hybrid/schur_complement_W.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using ClimaCore.Utilities: half
const compose = Operators.ComposeStencils()
const apply = Operators.ApplyStencil()

struct SchurComplementW{F, FT, J1, J2, J3, J4, S}
struct SchurComplementW{F, FT, J1, J2, J3, J4, S, T}
# whether this struct is used to compute Wfact_t or Wfact
transform::Bool

Expand All @@ -28,6 +28,10 @@ struct SchurComplementW{F, FT, J1, J2, J3, J4, S}

# whether to test the Jacobian and linear solver
test::Bool

# scratch buffers that the generic ldiv! method uses to stage b and the result
temp1::T
temp2::T
end

function SchurComplementW(Y, transform, flags, test = false)
Expand Down Expand Up @@ -61,6 +65,7 @@ function SchurComplementW(Y, transform, flags, test = false)
typeof(∂ᶠ𝕄ₜ∂ᶜρ),
typeof(∂ᶠ𝕄ₜ∂ᶠ𝕄),
typeof(S),
typeof(Y),
}(
transform,
flags,
Expand All @@ -72,6 +77,8 @@ function SchurComplementW(Y, transform, flags, test = false)
∂ᶠ𝕄ₜ∂ᶠ𝕄,
S,
test,
similar(Y),
similar(Y),
)
end

Expand Down Expand Up @@ -101,91 +108,107 @@ Finally, use (1) and (2) to get x1 and x2.
Note: The matrix S = A31 A13 + A32 A23 + A33 - I is the "Schur complement" of
[-I 0; 0 -I] (the top-left 4 blocks) in A.
=#
function linsolve!(::Type{Val{:init}}, f, u0; kwargs...)
function _linsolve!(x, A, b, update_matrix = false; kwargs...)
(; dtγ_ref, ∂ᶜρₜ∂ᶠ𝕄, ∂ᶜ𝔼ₜ∂ᶠ𝕄, ∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶠ𝕄ₜ∂ᶠ𝕄) = A
(; S) = A
dtγ = dtγ_ref[]

xᶜρ = x.c.ρ
bᶜρ = b.c.ρ
if :ρθ in propertynames(x.c)
xᶜ𝔼 = x.c.ρθ
bᶜ𝔼 = b.c.ρθ
elseif :ρe in propertynames(x.c)
xᶜ𝔼 = x.c.ρe
bᶜ𝔼 = b.c.ρe
elseif :ρe_int in propertynames(x.c)
xᶜ𝔼 = x.c.ρe_int
bᶜ𝔼 = b.c.ρe_int
end
if :ρw in propertynames(x.f)
xᶠ𝕄 = x.f.ρw.components.data.:1
bᶠ𝕄 = b.f.ρw.components.data.:1
elseif :w in propertynames(x.f)
xᶠ𝕄 = x.f.w.components.data.:1
bᶠ𝕄 = b.f.w.components.data.:1
end
# Function required by OrdinaryDiffEq.jl: when called with `Val{:init}`, it
# hands back the function that performs the actual linear solves. The `f`,
# `u0` and keyword arguments are accepted but not used.
function linsolve!(::Type{Val{:init}}, f, u0; kwargs...)
    return _linsolve!
end
# Solve for `x` given the operator `A` and right-hand side `b` by forwarding
# to `LinearAlgebra.ldiv!`. The `update_matrix` flag and any keyword
# arguments are accepted but ignored.
function _linsolve!(x, A, b, update_matrix = false; kwargs...)
    return LinearAlgebra.ldiv!(x, A, b)
end

# Function required by Krylov.jl (x and b can be AbstractVectors).
# The right-hand side is first copied into the buffer A.temp1, ldiv! is then
# called again on the two buffers (writing into A.temp2), and finally the
# contents of A.temp2 are copied into x.
# See https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/605 for a
# related issue that requires the same workaround.
function LinearAlgebra.ldiv!(x, A::SchurComplementW, b)
    (; temp1, temp2) = A
    temp1 .= b
    LinearAlgebra.ldiv!(temp2, A, temp1)
    return x .= temp2
end

# TODO: Extend LinearAlgebra.I to work with stencil fields.
FT = eltype(eltype(S))
I = Ref(Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT))))
if Operators.bandwidths(eltype(∂ᶜ𝔼ₜ∂ᶠ𝕄)) != (-half, half)
str = "The linear solver cannot yet be run with the given ∂ᶜ𝔼ₜ/∂ᶠ𝕄 \
block, since it has more than 2 diagonals. So, ∂ᶜ𝔼ₜ/∂ᶠ𝕄 will \
be set to 0 for the Schur complement computation. Consider \
changing the ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode or the energy variable."
@warn str maxlog = 1
@. S = dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) + dtγ * ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I
else
@. S =
dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) +
dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶜ𝔼ₜ∂ᶠ𝕄) +
dtγ * ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I
end
function LinearAlgebra.ldiv!(
x::Fields.FieldVector,
A::SchurComplementW,
b::Fields.FieldVector,
)
(; dtγ_ref, ∂ᶜρₜ∂ᶠ𝕄, ∂ᶜ𝔼ₜ∂ᶠ𝕄, ∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶠ𝕄ₜ∂ᶠ𝕄) = A
(; S) = A
dtγ = dtγ_ref[]

xᶜρ = x.c.ρ
bᶜρ = b.c.ρ
if :ρθ in propertynames(x.c)
xᶜ𝔼 = x.c.ρθ
bᶜ𝔼 = b.c.ρθ
elseif :ρe in propertynames(x.c)
xᶜ𝔼 = x.c.ρe
bᶜ𝔼 = b.c.ρe
elseif :ρe_int in propertynames(x.c)
xᶜ𝔼 = x.c.ρe_int
bᶜ𝔼 = b.c.ρe_int
end
if :ρw in propertynames(x.f)
xᶠ𝕄 = x.f.ρw.components.data.:1
bᶠ𝕄 = b.f.ρw.components.data.:1
elseif :w in propertynames(x.f)
xᶠ𝕄 = x.f.w.components.data.:1
bᶠ𝕄 = b.f.w.components.data.:1
end

@. xᶠ𝕄 = bᶠ𝕄 + dtγ * (apply(∂ᶠ𝕄ₜ∂ᶜρ, bᶜρ) + apply(∂ᶠ𝕄ₜ∂ᶜ𝔼, bᶜ𝔼))

Operators.column_thomas_solve!(S, xᶠ𝕄)

@. xᶜρ = -bᶜρ + dtγ * apply(∂ᶜρₜ∂ᶠ𝕄, xᶠ𝕄)
@. xᶜ𝔼 = -bᶜ𝔼 + dtγ * apply(∂ᶜ𝔼ₜ∂ᶠ𝕄, xᶠ𝕄)

if A.test && Operators.bandwidths(eltype(∂ᶜ𝔼ₜ∂ᶠ𝕄)) == (-half, half)
Ni, Nj, _, Nv, Nh = size(Spaces.local_geometry_data(axes(xᶜρ)))
∂Yₜ∂Y = Array{FT}(undef, 3 * Nv + 1, 3 * Nv + 1)
ΔY = Array{FT}(undef, 3 * Nv + 1)
ΔΔY = Array{FT}(undef, 3 * Nv + 1)
for h in 1:Nh, j in 1:Nj, i in 1:Ni
∂Yₜ∂Y .= zero(FT)
∂Yₜ∂Y[1:Nv, (2 * Nv + 1):(3 * Nv + 1)] .=
matrix_column(∂ᶜρₜ∂ᶠ𝕄, axes(x.f), i, j, h)
∂Yₜ∂Y[(Nv + 1):(2 * Nv), (2 * Nv + 1):(3 * Nv + 1)] .=
matrix_column(∂ᶜ𝔼ₜ∂ᶠ𝕄, axes(x.f), i, j, h)
∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), 1:Nv] .=
matrix_column(∂ᶠ𝕄ₜ∂ᶜρ, axes(x.c), i, j, h)
∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), (Nv + 1):(2 * Nv)] .=
matrix_column(∂ᶠ𝕄ₜ∂ᶜ𝔼, axes(x.c), i, j, h)
∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), (2 * Nv + 1):(3 * Nv + 1)] .=
matrix_column(∂ᶠ𝕄ₜ∂ᶠ𝕄, axes(x.f), i, j, h)
ΔY[1:Nv] .= vector_column(xᶜρ, i, j, h)
ΔY[(Nv + 1):(2 * Nv)] .= vector_column(xᶜ𝔼, i, j, h)
ΔY[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(xᶠ𝕄, i, j, h)
ΔΔY[1:Nv] .= vector_column(bᶜρ, i, j, h)
ΔΔY[(Nv + 1):(2 * Nv)] .= vector_column(bᶜ𝔼, i, j, h)
ΔΔY[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(bᶠ𝕄, i, j, h)
@assert (-LinearAlgebra.I + dtγ * ∂Yₜ∂Y) * ΔY ≈ ΔΔY
end
end
# TODO: Extend LinearAlgebra.I to work with stencil fields.
FT = eltype(eltype(S))
I = Ref(Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT))))
if Operators.bandwidths(eltype(∂ᶜ𝔼ₜ∂ᶠ𝕄)) != (-half, half)
str = "The linear solver cannot yet be run with the given ∂ᶜ𝔼ₜ/∂ᶠ𝕄 \
block, since it has more than 2 diagonals. So, ∂ᶜ𝔼ₜ/∂ᶠ𝕄 will \
be set to 0 for the Schur complement computation. Consider \
changing the ∂ᶜ𝔼ₜ∂ᶠ𝕄_mode or the energy variable."
@warn str maxlog = 1
@. S = dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) + dtγ * ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I
else
@. S =
dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) +
dtγ^2 * compose(∂ᶠ𝕄ₜ∂ᶜ𝔼, ∂ᶜ𝔼ₜ∂ᶠ𝕄) +
dtγ * ∂ᶠ𝕄ₜ∂ᶠ𝕄 - I
end

if :ρuₕ in propertynames(x.c)
@. x.c.ρuₕ = -b.c.ρuₕ
elseif :uₕ in propertynames(x.c)
@. x.c.uₕ = -b.c.uₕ
@. xᶠ𝕄 = bᶠ𝕄 + dtγ * (apply(∂ᶠ𝕄ₜ∂ᶜρ, bᶜρ) + apply(∂ᶠ𝕄ₜ∂ᶜ𝔼, bᶜ𝔼))

Operators.column_thomas_solve!(S, xᶠ𝕄)

@. xᶜρ = -bᶜρ + dtγ * apply(∂ᶜρₜ∂ᶠ𝕄, xᶠ𝕄)
@. xᶜ𝔼 = -bᶜ𝔼 + dtγ * apply(∂ᶜ𝔼ₜ∂ᶠ𝕄, xᶠ𝕄)

if A.test && Operators.bandwidths(eltype(∂ᶜ𝔼ₜ∂ᶠ𝕄)) == (-half, half)
Ni, Nj, _, Nv, Nh = size(Spaces.local_geometry_data(axes(xᶜρ)))
∂Yₜ∂Y = Array{FT}(undef, 3 * Nv + 1, 3 * Nv + 1)
ΔY = Array{FT}(undef, 3 * Nv + 1)
ΔΔY = Array{FT}(undef, 3 * Nv + 1)
for h in 1:Nh, j in 1:Nj, i in 1:Ni
∂Yₜ∂Y .= zero(FT)
∂Yₜ∂Y[1:Nv, (2 * Nv + 1):(3 * Nv + 1)] .=
matrix_column(∂ᶜρₜ∂ᶠ𝕄, axes(x.f), i, j, h)
∂Yₜ∂Y[(Nv + 1):(2 * Nv), (2 * Nv + 1):(3 * Nv + 1)] .=
matrix_column(∂ᶜ𝔼ₜ∂ᶠ𝕄, axes(x.f), i, j, h)
∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), 1:Nv] .=
matrix_column(∂ᶠ𝕄ₜ∂ᶜρ, axes(x.c), i, j, h)
∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), (Nv + 1):(2 * Nv)] .=
matrix_column(∂ᶠ𝕄ₜ∂ᶜ𝔼, axes(x.c), i, j, h)
∂Yₜ∂Y[(2 * Nv + 1):(3 * Nv + 1), (2 * Nv + 1):(3 * Nv + 1)] .=
matrix_column(∂ᶠ𝕄ₜ∂ᶠ𝕄, axes(x.f), i, j, h)
ΔY[1:Nv] .= vector_column(xᶜρ, i, j, h)
ΔY[(Nv + 1):(2 * Nv)] .= vector_column(xᶜ𝔼, i, j, h)
ΔY[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(xᶠ𝕄, i, j, h)
ΔΔY[1:Nv] .= vector_column(bᶜρ, i, j, h)
ΔΔY[(Nv + 1):(2 * Nv)] .= vector_column(bᶜ𝔼, i, j, h)
ΔΔY[(2 * Nv + 1):(3 * Nv + 1)] .= vector_column(bᶠ𝕄, i, j, h)
@assert (-LinearAlgebra.I + dtγ * ∂Yₜ∂Y) * ΔY ≈ ΔΔY
end
end

if A.transform
x .*= dtγ
end
if :ρuₕ in propertynames(x.c)
@. x.c.ρuₕ = -b.c.ρuₕ
elseif :uₕ in propertynames(x.c)
@. x.c.uₕ = -b.c.uₕ
end

if A.transform
x .*= dtγ
end
end
8 changes: 4 additions & 4 deletions test/Operators/finitedifference/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space)
=#
face_space = Spaces.FaceFiniteDifferenceSpace(center_space)

function _linsolve!(x, A, b, update_matrix = false; kwargs...)
function test_linsolve!(x, A, b, update_matrix = false; kwargs...)

FT = Spaces.undertype(axes(x.c))

Expand Down Expand Up @@ -88,11 +88,11 @@ W = SchurComplementW(Y, use_transform, jacobi_flags)

using JET
using Test
@time _linsolve!(Y, W, b)
@time _linsolve!(Y, W, b)
@time test_linsolve!(Y, W, b)
@time test_linsolve!(Y, W, b)

@testset "JET test for `apply` in linsolve! kernel" begin
@test_opt _linsolve!(Y, W, b)
@test_opt test_linsolve!(Y, W, b)
end

ClimaCore.Operators.allow_mismatched_fd_spaces() = false
Loading