Skip to content

Commit

Permalink
Merge pull request #134 from paresula:bugfix/assembly_BilForm
Browse files Browse the repository at this point in the history
Change assembly of BilForm
  • Loading branch information
krcools authored Jun 28, 2024
2 parents fada046 + 6c6d8fe commit 8b22ac9
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 3 deletions.
15 changes: 13 additions & 2 deletions src/solvers/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,17 @@ lift(a::ConvolutionOperators.AbstractConvOp ,I,J,U,V) =
function assemble(bf::BilForm, X::DirectProductSpace, Y::DirectProductSpace;
materialize=BEAST.assemble)

T = Int32
@assert !isempty(bf.terms)

spaceTimeBasis = isa(X.factors[1], BEAST.SpaceTimeBasis)

if spaceTimeBasis
p = [numstages(temporalbasis(ch)) for ch in X.factors]
lincombv = ConvolutionOperators.LiftedConvOp[]
else
p = 1
lincombv = LinearMap[]
end

M = numfunctions.(spatialbasis(X).factors) .* p
Expand All @@ -237,7 +240,7 @@ function assemble(bf::BilForm, X::DirectProductSpace, Y::DirectProductSpace;
U = BlockArrays.blockedrange(M)
V = BlockArrays.blockedrange(N)

sum(bf.terms) do term
for term in bf.terms

x = X.factors[term.test_id]
for op in reverse(term.test_ops)
Expand All @@ -251,7 +254,15 @@ function assemble(bf::BilForm, X::DirectProductSpace, Y::DirectProductSpace;

a = term.coeff * term.kernel
z = materialize(a, x, y)
lift(z, Block(term.test_id), Block(term.trial_id), U, V)

Smap = lift(z, Block(term.test_id), Block(term.trial_id), U, V)
T = promote_type(T, eltype(Smap))
push!(lincombv, Smap)
end
if spaceTimeBasis
return sum(lincombv)
else
return LinearMaps.LinearCombination{T}(lincombv)
end
end

Expand Down
18 changes: 17 additions & 1 deletion test/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.2"
manifest_format = "2.0"
project_hash = "571829245fb8bd97876fa91a4e7d398091f62379"
project_hash = "1b78102b70e82631772273acf6687b12f1d8fa74"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
Expand Down Expand Up @@ -299,6 +299,22 @@ git-tree-sha1 = "71e8ee0f9fe0e86a8f8c7f28361e5118eab2f93f"
uuid = "18c40d15-f7cd-5a6d-bc92-87468d86c5db"
version = "5.0.0+0"

[[deps.LinearMaps]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "ee79c3208e55786de58f8dcccca098ced79f743f"
uuid = "7a12625a-238d-50fd-b39a-03d52299707e"
version = "3.11.3"

[deps.LinearMaps.extensions]
LinearMapsChainRulesCoreExt = "ChainRulesCore"
LinearMapsSparseArraysExt = "SparseArrays"
LinearMapsStatisticsExt = "Statistics"

[deps.LinearMaps.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[deps.LogExpFunctions]]
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "c3ce8e7420b3a6e071e0fe4745f5d4300e37b13f"
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ CompScienceMeshes = "3e66a162-7b8c-5da0-b8f8-124ecd2c3ae1"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SauterSchwabQuadrature = "535c7bfe-2023-5c1d-b712-654ef9d93a38"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
6 changes: 6 additions & 0 deletions test/test_directproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using CompScienceMeshes
using BEAST

using Test
import LinearMaps

for U in [Float32, Float64]
m1 = meshrectangle(U(1.0), U(1.0), U(0.5))
Expand All @@ -25,4 +26,9 @@ for U in [Float32, Float64]
t = assemble(T, X, X)

@test size(t) == (nt,nt)

bilterms = [BEAST.Variational.BilTerm(1,1,Any[],Any[],1,T)]

BilForm = BEAST.Variational.BilForm(:i, :j, bilterms)
@test typeof(assemble(BilForm, X, X)) == LinearMaps.LinearCombination{U, Vector{LinearMaps.LinearMap}}
end

0 comments on commit 8b22ac9

Please sign in to comment.