From 9fe6af40b526de245993b6e4be13ca144d768796 Mon Sep 17 00:00:00 2001 From: PaulOlyslager Date: Thu, 19 Sep 2024 15:04:49 +0200 Subject: [PATCH] Implemented the composed basis functionallity. This allows us to multiply, crossproduct or inner product a basis with a function or the normal vector. All functionallity is screened away from the user. Discussion needed about what can be made visible. This commit will support the later implemented composed operators. --- src/BEAST.jl | 3 + src/bases/composedbasis.jl | 91 +++++++++++++++++++++++++++ src/bases/local/localcomposedbasis.jl | 53 ++++++++++++++++ test/runtests.jl | 1 + test/test_composed_basis.jl | 34 ++++++++++ 5 files changed, 182 insertions(+) create mode 100644 src/bases/composedbasis.jl create mode 100644 src/bases/local/localcomposedbasis.jl create mode 100644 test/test_composed_basis.jl diff --git a/src/BEAST.jl b/src/BEAST.jl index 14daed8a..783ce83a 100644 --- a/src/BEAST.jl +++ b/src/BEAST.jl @@ -182,6 +182,9 @@ include("bases/stagedtimestep.jl") include("bases/timebasis.jl") include("bases/tensorbasis.jl") +include("bases/composedbasis.jl") +include("bases/local/localcomposedbasis.jl") + include("operator.jl") include("quadrature/quadstrats.jl") diff --git a/src/bases/composedbasis.jl b/src/bases/composedbasis.jl new file mode 100644 index 00000000..9b767d99 --- /dev/null +++ b/src/bases/composedbasis.jl @@ -0,0 +1,91 @@ +struct FunctionWrapper{T} + func::Function + function FunctionWrapper(f::Function;evalpoint = @SVector [0.0,0.0,0.0]) + new{typeof(f(evalpoint))}(f) + end +end + +function (f::FunctionWrapper{T})(x)::T where {T} + return f.func(x) +end + +function (f::FunctionWrapper{T})(x::CompScienceMeshes.MeshPointNM)::T where {T} + return f.func(cartesian(x)) +end +scalartype(F::FunctionWrapper{T}) where {T} = eltype(T) +function scalartype(::NormalVector) + @warn "The scallartype of the NormalVector is set at Float32, if used in combination with Float64 basis or operator this is no problem, in the case of Float16 it is." + return Float32 +end + +abstract type _BasisOperations{T} <: Space{T} end + +struct _BasisTimes{T} <: _BasisOperations{T} + el1 + el2 + function _BasisTimes(el1,el2) + new{promote_type(scalartype(el1),scalartype(el2))}(el1,el2) + end + function _BasisTimes{T}(el1,el2) where {T} + new{T}(el1,el2) + end +end + +struct _BasisCross{T} <: _BasisOperations{T} + el1 + el2 + function _BasisCross(el1,el2) + new{promote_type(scalartype(el1),scalartype(el2))}(el1,el2) + end + function _BasisCross{T}(el1,el2) where {T} + new{T}(el1,el2) + end +end + +struct _BasisDot{T} <: _BasisOperations{T} + el1 + el2 + function _BasisDot(el1,el2) + new{promote_type(scalartype(el1),scalartype(el2))}(el1,el2) + end + function _BasisDot{T}(el1,el2) where {T} + new{T}(el1,el2) + end +end + +#### wrapping of the functions +_BasisTimes(a::Function,b::Function) = _BasisTimes(FunctionWrapper(a),FunctionWrapper(b)) +_BasisTimes(a::Function,b) = _BasisTimes(FunctionWrapper(a),b) +_BasisTimes(a,b::Function) = _BasisTimes(a,FunctionWrapper(b)) + +_BasisCross(a::Function,b::Function) = _BasisCross(FunctionWrapper(a),FunctionWrapper(b)) +_BasisCross(a::Function,b) = _BasisCross(FunctionWrapper(a),b) +_BasisCross(a,b::Function) = _BasisCross(a,FunctionWrapper(b)) + +_BasisDot(a::Function,b::Function) = _BasisDot(FunctionWrapper(a),FunctionWrapper(b)) +_BasisDot(a::Function,b) = _BasisDot(FunctionWrapper(a),b) +_BasisDot(a,b::Function) = _BasisDot(a,FunctionWrapper(b)) + + +refspace(a::_BasisTimes{T}) where {T} = _LocalBasisTimes(T,refspace(a.el1),refspace(a.el2)) +refspace(a::_BasisCross{T}) where {T} = _LocalBasisCross(T,refspace(a.el1),refspace(a.el2)) +refspace(a::_BasisDot{T}) where {T} = _LocalBasisDot(T,refspace(a.el1),refspace(a.el2)) +refspace(a::Function) = FunctionWrapper(a) +refspace(a::FunctionWrapper) = a +refspace(a::NormalVector) = a + +numfunctions(a::Union{NormalVector,FunctionWrapper,Function}) = missing +numfunctions(a::_BasisOperations) = coalesce(numfunctions(a.el1) , numfunctions(a.el2)) + +geometry(a::_BasisOperations) = coalesce(geometry(a.el1),geometry(a.el2)) +basisfunction(a::_BasisOperations,i) = coalesce(basisfunction(a.el1,i),basisfunction(a.el2,i)) +positions(a::_BasisOperations) = coalesce(positions(a.el1),positions(a.el2)) + +geometry(a::Union{NormalVector,FunctionWrapper,Function}) = missing +basisfunction(a::Union{NormalVector,FunctionWrapper,Function},i) = missing +positions(a::Union{NormalVector,FunctionWrapper,Function}) = missing + + +subset(a::T,I) where {T <: _BasisOperations} = T(subset(a.el1,I),subset(a.el2,I)) +subset(a::Union{NormalVector,FunctionWrapper,Function},I) = a + diff --git a/src/bases/local/localcomposedbasis.jl b/src/bases/local/localcomposedbasis.jl new file mode 100644 index 00000000..7b169bec --- /dev/null +++ b/src/bases/local/localcomposedbasis.jl @@ -0,0 +1,53 @@ +abstract type _LocalBasisOperations{T} <: RefSpace{T,:None} end +numfunctions(a::_LocalBasisOperations) = coalesce(numfunctions(a.el1) , numfunctions(a.el2)) + +struct _LocalBasisTimes{T,U,V} <: _LocalBasisOperations{T} + el1::U + el2::V + function _LocalBasisTimes(::Type{T},el1::U,el2::V) where {T,U,V} + new{T,U,V}(el1,el2) + end +end + +struct _LocalBasisCross{T,U,V} <: _LocalBasisOperations{T} + el1::U + el2::V + function _LocalBasisCross(::Type{T},el1::U,el2::V) where {T,U,V} + new{T,U,V}(el1,el2) + end +end + +struct _LocalBasisDot{T,U,V} <: _LocalBasisOperations{T} + el1::U + el2::V + function _LocalBasisDot(::Type{T},el1::U,el2::V) where {T,U,V} + new{T,U,V}(el1,el2) + end +end + +function (op::U where {U <: _LocalBasisOperations})(p) + l = op.el1(p) + r = op.el2(p) + return _execute_operation(l,r,op) +end + +function (op::NormalVector)(x::CompScienceMeshes.MeshPointNM) + return normal(x) +end + +operation(a::_LocalBasisTimes) = * +operation(a::_LocalBasisCross) = × +operation(a::_LocalBasisDot) = ⋅ + +_execute_operation(el1::SVector{N,<:NamedTuple},el2::SVector,op::U) where {N,U} = SVector{N}(_execute_operation_named(el1.data,el2,operation(op))) +_execute_operation(el1::SVector{N,<:NamedTuple},el2::U,op::_LocalBasisTimes) where {N,U <: Number} = SVector{N}(_execute_operation_named(el1.data,el2,operation(op))) +_execute_operation(el1::SVector,el2::SVector{N,<:NamedTuple},op::U) where {N,U} = SVector{N}(_execute_operation_named(el1,el2.data,operation(op))) +_execute_operation(el1::U,el2::SVector{N,<:NamedTuple},op::_LocalBasisTimes) where {N,U <: Number} = SVector{N}(_execute_operation_named(el1,el2.data,operation(op))) +_execute_operation(el1::SVector,el2::SVector,op::U) where {U <: Union{_LocalBasisCross,_LocalBasisDot}} = operation(op)(el1,el2) +_execute_operation(el1::U,el2::V,op::_LocalBasisTimes) where {U <: Number,V <: Number}= el1*el2 +_execute_operation(el1::SVector{N,<:NamedTuple},el2::SVector{M,<:NamedTuple},op) where {N,M} = @error "multiplication of basisses not (yet) supported" + +_execute_operation_named(a::NTuple{N},b,op) where {N} = ((value=op(a[1].value,b),), _execute_operation_named(Base.tail(a),b,op)...) +_execute_operation_named(a::NTuple{1},b,op) = ((value=op(a[1].value,b),),) +_execute_operation_named(b,a::NTuple{N},op) where {N} = ((value=op(b,a[1].value),), _execute_operation_named(b,Base.tail(a),op)...) +_execute_operation_named(b,a::NTuple{1},op) = ((value=op(b,a[1].value),),) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 956d5153..888b814e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -78,6 +78,7 @@ include("test_variational.jl") include("test_handlers.jl") +include("test_composed_basis.jl") using TestItemRunner @run_package_tests diff --git a/test/test_composed_basis.jl b/test/test_composed_basis.jl new file mode 100644 index 00000000..078bf870 --- /dev/null +++ b/test/test_composed_basis.jl @@ -0,0 +1,34 @@ +######## test multiplied basis +using CompScienceMeshes +using LinearAlgebra +using BEAST +using Test + +Γ = meshcuboid(1.0,1.0,1.0,1.0) +X = raviartthomas(Γ) +f(x) = 2.0 +Y = BEAST._BasisTimes(X,f) +K = Maxwell3D.doublelayer(wavenumber=1.0) +m1 = assemble(K,Y,Y) +m2 = assemble(K,X,X) +@test m1 ≈ 4*m2 + +U = BEAST._BasisTimes(BEAST._BasisDot(x->(@SVector [1.0,1.0,1.0]),n),X) +m = assemble(K,U,U) + +L = lagrangecxd0(Γ) +Z = BEAST._BasisTimes(L,n) +m3 = assemble(K,Z,Z) +@test norm(m3) ≈ 0.08420116178577139 + + + +##### go in code +# using StaticArrays +# s = simplex((@SVector [1.0,0.0,0.0]),(@SVector [0.0,1.0,0.0]),(@SVector [0.0,0.0,0.0])) +# p = neighborhood(s,[0.5,0.2]) + +# rtref = BEAST.RTRefSpace{Float64}() +# f(x)= 2.0 + +# mref = BEAST._LocalBasisTimes(Float64,BEAST.FunctionWrapper(f),rtref)