Skip to content

Commit

Permalink
Implemented the composed basis functionallity. This allows us to mult…
Browse files Browse the repository at this point in the history
…iply, crossproduct or inner product a basis with a function or the normal vector.

All functionallity is screened away from the user.
Discussion needed about what can be made visible.
This commit will support the later implemented composed operators.
  • Loading branch information
PaulOlyslager committed Sep 19, 2024
1 parent c6e5853 commit 9fe6af4
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/BEAST.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ include("bases/stagedtimestep.jl")
include("bases/timebasis.jl")
include("bases/tensorbasis.jl")

include("bases/composedbasis.jl")
include("bases/local/localcomposedbasis.jl")

include("operator.jl")

include("quadrature/quadstrats.jl")
Expand Down
91 changes: 91 additions & 0 deletions src/bases/composedbasis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
struct FunctionWrapper{T}
func::Function
function FunctionWrapper(f::Function;evalpoint = @SVector [0.0,0.0,0.0])
new{typeof(f(evalpoint))}(f)
end
end

function (f::FunctionWrapper{T})(x)::T where {T}
return f.func(x)
end

function (f::FunctionWrapper{T})(x::CompScienceMeshes.MeshPointNM)::T where {T}
return f.func(cartesian(x))
end
scalartype(F::FunctionWrapper{T}) where {T} = eltype(T)
function scalartype(::NormalVector)
@warn "The scallartype of the NormalVector is set at Float32, if used in combination with Float64 basis or operator this is no problem, in the case of Float16 it is."
return Float32
end

abstract type _BasisOperations{T} <: Space{T} end

struct _BasisTimes{T} <: _BasisOperations{T}
el1
el2
function _BasisTimes(el1,el2)
new{promote_type(scalartype(el1),scalartype(el2))}(el1,el2)
end
function _BasisTimes{T}(el1,el2) where {T}
new{T}(el1,el2)
end
end

struct _BasisCross{T} <: _BasisOperations{T}
el1
el2
function _BasisCross(el1,el2)
new{promote_type(scalartype(el1),scalartype(el2))}(el1,el2)
end
function _BasisCross{T}(el1,el2) where {T}
new{T}(el1,el2)
end
end

struct _BasisDot{T} <: _BasisOperations{T}
el1
el2
function _BasisDot(el1,el2)
new{promote_type(scalartype(el1),scalartype(el2))}(el1,el2)
end
function _BasisDot{T}(el1,el2) where {T}
new{T}(el1,el2)
end
end

#### wrapping of the functions
_BasisTimes(a::Function,b::Function) = _BasisTimes(FunctionWrapper(a),FunctionWrapper(b))
_BasisTimes(a::Function,b) = _BasisTimes(FunctionWrapper(a),b)
_BasisTimes(a,b::Function) = _BasisTimes(a,FunctionWrapper(b))

_BasisCross(a::Function,b::Function) = _BasisCross(FunctionWrapper(a),FunctionWrapper(b))
_BasisCross(a::Function,b) = _BasisCross(FunctionWrapper(a),b)
_BasisCross(a,b::Function) = _BasisCross(a,FunctionWrapper(b))

_BasisDot(a::Function,b::Function) = _BasisDot(FunctionWrapper(a),FunctionWrapper(b))
_BasisDot(a::Function,b) = _BasisDot(FunctionWrapper(a),b)
_BasisDot(a,b::Function) = _BasisDot(a,FunctionWrapper(b))


refspace(a::_BasisTimes{T}) where {T} = _LocalBasisTimes(T,refspace(a.el1),refspace(a.el2))
refspace(a::_BasisCross{T}) where {T} = _LocalBasisCross(T,refspace(a.el1),refspace(a.el2))
refspace(a::_BasisDot{T}) where {T} = _LocalBasisDot(T,refspace(a.el1),refspace(a.el2))
refspace(a::Function) = FunctionWrapper(a)
refspace(a::FunctionWrapper) = a
refspace(a::NormalVector) = a

numfunctions(a::Union{NormalVector,FunctionWrapper,Function}) = missing
numfunctions(a::_BasisOperations) = coalesce(numfunctions(a.el1) , numfunctions(a.el2))

geometry(a::_BasisOperations) = coalesce(geometry(a.el1),geometry(a.el2))
basisfunction(a::_BasisOperations,i) = coalesce(basisfunction(a.el1,i),basisfunction(a.el2,i))
positions(a::_BasisOperations) = coalesce(positions(a.el1),positions(a.el2))

geometry(a::Union{NormalVector,FunctionWrapper,Function}) = missing
basisfunction(a::Union{NormalVector,FunctionWrapper,Function},i) = missing
positions(a::Union{NormalVector,FunctionWrapper,Function}) = missing


subset(a::T,I) where {T <: _BasisOperations} = T(subset(a.el1,I),subset(a.el2,I))
subset(a::Union{NormalVector,FunctionWrapper,Function},I) = a

53 changes: 53 additions & 0 deletions src/bases/local/localcomposedbasis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
abstract type _LocalBasisOperations{T} <: RefSpace{T,:None} end
numfunctions(a::_LocalBasisOperations) = coalesce(numfunctions(a.el1) , numfunctions(a.el2))

struct _LocalBasisTimes{T,U,V} <: _LocalBasisOperations{T}
el1::U
el2::V
function _LocalBasisTimes(::Type{T},el1::U,el2::V) where {T,U,V}
new{T,U,V}(el1,el2)
end
end

struct _LocalBasisCross{T,U,V} <: _LocalBasisOperations{T}
el1::U
el2::V
function _LocalBasisCross(::Type{T},el1::U,el2::V) where {T,U,V}
new{T,U,V}(el1,el2)
end
end

struct _LocalBasisDot{T,U,V} <: _LocalBasisOperations{T}
el1::U
el2::V
function _LocalBasisDot(::Type{T},el1::U,el2::V) where {T,U,V}
new{T,U,V}(el1,el2)
end
end

function (op::U where {U <: _LocalBasisOperations})(p)
l = op.el1(p)
r = op.el2(p)
return _execute_operation(l,r,op)
end

function (op::NormalVector)(x::CompScienceMeshes.MeshPointNM)
return normal(x)
end

operation(a::_LocalBasisTimes) = *
operation(a::_LocalBasisCross) = ×
operation(a::_LocalBasisDot) =

_execute_operation(el1::SVector{N,<:NamedTuple},el2::SVector,op::U) where {N,U} = SVector{N}(_execute_operation_named(el1.data,el2,operation(op)))
_execute_operation(el1::SVector{N,<:NamedTuple},el2::U,op::_LocalBasisTimes) where {N,U <: Number} = SVector{N}(_execute_operation_named(el1.data,el2,operation(op)))
_execute_operation(el1::SVector,el2::SVector{N,<:NamedTuple},op::U) where {N,U} = SVector{N}(_execute_operation_named(el1,el2.data,operation(op)))
_execute_operation(el1::U,el2::SVector{N,<:NamedTuple},op::_LocalBasisTimes) where {N,U <: Number} = SVector{N}(_execute_operation_named(el1,el2.data,operation(op)))
_execute_operation(el1::SVector,el2::SVector,op::U) where {U <: Union{_LocalBasisCross,_LocalBasisDot}} = operation(op)(el1,el2)
_execute_operation(el1::U,el2::V,op::_LocalBasisTimes) where {U <: Number,V <: Number}= el1*el2
_execute_operation(el1::SVector{N,<:NamedTuple},el2::SVector{M,<:NamedTuple},op) where {N,M} = @error "multiplication of basisses not (yet) supported"

_execute_operation_named(a::NTuple{N},b,op) where {N} = ((value=op(a[1].value,b),), _execute_operation_named(Base.tail(a),b,op)...)
_execute_operation_named(a::NTuple{1},b,op) = ((value=op(a[1].value,b),),)
_execute_operation_named(b,a::NTuple{N},op) where {N} = ((value=op(b,a[1].value),), _execute_operation_named(b,Base.tail(a),op)...)
_execute_operation_named(b,a::NTuple{1},op) = ((value=op(b,a[1].value),),)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ include("test_variational.jl")

include("test_handlers.jl")

include("test_composed_basis.jl")
using TestItemRunner
@run_package_tests

Expand Down
34 changes: 34 additions & 0 deletions test/test_composed_basis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
######## test multiplied basis
using CompScienceMeshes
using LinearAlgebra
using BEAST
using Test

Γ = meshcuboid(1.0,1.0,1.0,1.0)
X = raviartthomas(Γ)
f(x) = 2.0
Y = BEAST._BasisTimes(X,f)
K = Maxwell3D.doublelayer(wavenumber=1.0)
m1 = assemble(K,Y,Y)
m2 = assemble(K,X,X)
@test m1 4*m2

U = BEAST._BasisTimes(BEAST._BasisDot(x->(@SVector [1.0,1.0,1.0]),n),X)
m = assemble(K,U,U)

L = lagrangecxd0(Γ)
Z = BEAST._BasisTimes(L,n)
m3 = assemble(K,Z,Z)
@test norm(m3) 0.08420116178577139



##### go in code
# using StaticArrays
# s = simplex((@SVector [1.0,0.0,0.0]),(@SVector [0.0,1.0,0.0]),(@SVector [0.0,0.0,0.0]))
# p = neighborhood(s,[0.5,0.2])

# rtref = BEAST.RTRefSpace{Float64}()
# f(x)= 2.0

# mref = BEAST._LocalBasisTimes(Float64,BEAST.FunctionWrapper(f),rtref)

0 comments on commit 9fe6af4

Please sign in to comment.