Skip to content

Commit

Permalink
Storage type (#23)
Browse files Browse the repository at this point in the history
Co-authored-by: Tamas Hakkel <tamas.hakkel@mediso.com>
  • Loading branch information
hakkelt and Tamas Hakkel authored Apr 11, 2024
1 parent 4dea8fb commit 7413876
Show file tree
Hide file tree
Showing 46 changed files with 468 additions and 437 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AbstractOperators"
uuid = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c"
version = "0.3"
version = "0.4"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
4 changes: 2 additions & 2 deletions src/calculus/AdjointOperator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ export AdjointOperator
"""
`AdjointOperator(A::AbstractOperator)`
Shorthand constructor:
Shorthand constructor:
`'(A::AbstractOperator)`
Expand All @@ -19,7 +19,7 @@ julia> [DFT(10); DCT(10)]'
"""
struct AdjointOperator{T <: AbstractOperator} <: AbstractOperator
A::T
function AdjointOperator(A::T) where {T<:AbstractOperator}
function AdjointOperator(A::T) where {T<:AbstractOperator}
is_linear(A) == false && error("Cannot transpose a nonlinear operator. You might use `jacobian`")
new{T}(A)
end
Expand Down
18 changes: 9 additions & 9 deletions src/calculus/AffineAdd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ export AffineAdd
"""
`AffineAdd(A::AbstractOperator, d, [sign = true])`
Affine addition to `AbstractOperator` with an array or scalar `d`.
Affine addition to `AbstractOperator` with an array or scalar `d`.
Use `sign = false` to perform subtraction.
Expand All @@ -26,17 +26,17 @@ true
struct AffineAdd{L <: AbstractOperator, D <: Union{AbstractArray, Number}, S} <: AbstractOperator
A::L
d::D
function AffineAdd(A::L, d::D, sign::Bool = true) where {L, D <: AbstractArray}
if size(d) != size(A,1)
function AffineAdd(A::L, d::D, sign::Bool = true) where {L, D <: AbstractArray}
if size(d) != size(A,1)
throw(DimensionMismatch("codomain size of $A not compatible with array `d` of size $(size(d))"))
end
if eltype(d) != codomainType(A)
if eltype(d) != codomainType(A)
error("cannot tilt opertor having codomain type $(codomainType(A)) with array of type $(eltype(d))")
end
new{L,D,sign}(A,d)
end
# scalar
function AffineAdd(A::L, d::D, sign::Bool = true) where {L, D <: Number}
function AffineAdd(A::L, d::D, sign::Bool = true) where {L, D <: Number}
if typeof(d) <: Complex && codomainType(A) <: Real
error("cannot tilt opertor having codomain type $(codomainType(A)) with array of type $(eltype(d))")
end
Expand All @@ -46,12 +46,12 @@ end

# Mappings
# array
function mul!(y::DD, T::AffineAdd{L, D, true}, x) where {L <: AbstractOperator, DD, D}
function mul!(y::DD, T::AffineAdd{L, D, true}, x) where {L <: AbstractOperator, DD, D}
mul!(y,T.A,x)
y .+= T.d
end

function mul!(y::DD, T::AffineAdd{L, D, false}, x) where {L <: AbstractOperator, DD, D}
function mul!(y::DD, T::AffineAdd{L, D, false}, x) where {L <: AbstractOperator, DD, D}
mul!(y,T.A,x)
y .-= T.d
end
Expand All @@ -70,7 +70,7 @@ is_null(L::AffineAdd) = is_null(L.A)
is_eye(L::AffineAdd) = is_diagonal(L.A)
is_diagonal(L::AffineAdd) = is_diagonal(L.A)
is_invertible(L::AffineAdd) = is_invertible(L.A)
is_AcA_diagonal(L::AffineAdd) = is_AcA_diagonal(L.A)
is_AcA_diagonal(L::AffineAdd) = is_AcA_diagonal(L.A)
is_AAc_diagonal(L::AffineAdd) = is_AAc_diagonal(L.A)
is_full_row_rank(L::AffineAdd) = is_full_row_rank(L.A)
is_full_column_rank(L::AffineAdd) = is_full_column_rank(L.A)
Expand All @@ -90,7 +90,7 @@ sign(T::AffineAdd{L,D, true}) where {L,D} = 1

function permute(T::AffineAdd{L,D,S}, p::AbstractVector{Int}) where {L,D,S}
A = permute(T.A,p)
return AffineAdd(A,T.d,S)
return AffineAdd(A,T.d,S)
end

displacement(A::AffineAdd{L,D,true}) where {L,D} = A.d .+ displacement(A.A)
Expand Down
19 changes: 8 additions & 11 deletions src/calculus/Ax_mul_Bx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,10 @@ end

# Constructors
function Ax_mul_Bx(A::AbstractOperator,B::AbstractOperator)
s,t = size(A,1), codomainType(A)
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(B,1), codomainType(B)
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(A,2), domainType(A)
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufA = allocateInCodomain(A)
bufB = allocateInCodomain(B)
bufC = allocateInCodomain(B)
bufD = allocateInDomain(A)
Ax_mul_Bx(A,B,bufA,bufB,bufC,bufD)
end

Expand Down Expand Up @@ -95,16 +92,16 @@ end

size(P::Union{Ax_mul_Bx,Ax_mul_BxJac}) = ((size(P.A,1)[1],size(P.B,1)[2]),size(P.A,2))

fun_name(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)
fun_name(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)

domainType(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = domainType(L.A)
codomainType(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = codomainType(L.A)

# utils
function permute(P::Ax_mul_Bx{L1,L2,C,D},
function permute(P::Ax_mul_Bx{L1,L2,C,D},
p::AbstractVector{Int}) where {L1,L2,C,D <:ArrayPartition}
Ax_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
Ax_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]))
end

remove_displacement(P::Ax_mul_Bx) =
remove_displacement(P::Ax_mul_Bx) =
Ax_mul_Bx(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)
21 changes: 9 additions & 12 deletions src/calculus/Ax_mul_Bxt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ struct Ax_mul_Bxt{
bufD::D
function Ax_mul_Bxt(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D}
if ndims(A,1) == 1
if size(A) != size(B)
if size(A) != size(B)
throw(DimensionMismatch("Cannot compose operators"))
end
elseif ndims(A,1) == 2 && ndims(B,1) == 2 && size(A,2) == size(B,2)
if size(A,1)[2] != size(B,1)[2]
if size(A,1)[2] != size(B,1)[2]
throw(DimensionMismatch("Cannot compose operators"))
end
else
Expand All @@ -68,13 +68,10 @@ end

# Constructors
function Ax_mul_Bxt(A::AbstractOperator,B::AbstractOperator)
s,t = size(A,1), codomainType(A)
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(B,1), codomainType(B)
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(A,2), domainType(A)
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufA = allocateInCodomain(A)
bufB = allocateInCodomain(B)
bufC = allocateInCodomain(A)
bufD = allocateInDomain(A)
Ax_mul_Bxt(A,B,bufA,bufB,bufC,bufD)
end

Expand Down Expand Up @@ -103,16 +100,16 @@ end

size(P::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = ((size(P.A,1)[1],size(P.B,1)[1]),size(P.A,2))

fun_name(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = fun_name(L.A)*"*"*fun_name(L.B)
fun_name(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = fun_name(L.A)*"*"*fun_name(L.B)

domainType(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = domainType(L.A)
codomainType(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = codomainType(L.A)

# utils
function permute(P::Ax_mul_Bxt{L1,L2,C,D},
function permute(P::Ax_mul_Bxt{L1,L2,C,D},
p::AbstractVector{Int}) where {L1,L2,C,D <:ArrayPartition}
Ax_mul_Bxt(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
end

remove_displacement(P::Ax_mul_Bxt) =
remove_displacement(P::Ax_mul_Bxt) =
Ax_mul_Bxt(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)
21 changes: 9 additions & 12 deletions src/calculus/Axt_mul_Bx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ struct Axt_mul_Bx{N,
bufD::D
function Axt_mul_Bx(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D}
if ndims(A,1) == 1
if size(A) != size(B)
if size(A) != size(B)
throw(DimensionMismatch("Cannot compose operators"))
end
elseif ndims(A,1) == 2 && ndims(B,1) == 2 && size(A,2) == size(B,2)
if size(A,1)[1] != size(B,1)[1]
if size(A,1)[1] != size(B,1)[1]
throw(DimensionMismatch("Cannot compose operators"))
end
else
Expand All @@ -69,13 +69,10 @@ end

# Constructors
function Axt_mul_Bx(A::AbstractOperator,B::AbstractOperator)
s,t = size(A,1), codomainType(A)
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(B,1), codomainType(B)
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(A,2), domainType(A)
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufA = allocateInCodomain(A)
bufB = allocateInCodomain(B)
bufC = allocateInCodomain(A)
bufD = allocateInDomain(A)
Axt_mul_Bx(A,B,bufA,bufB,bufC,bufD)
end

Expand Down Expand Up @@ -122,16 +119,16 @@ end
size(P::Union{Axt_mul_Bx{1},Axt_mul_BxJac{1}}) = ((1,),size(P.A,2))
size(P::Union{Axt_mul_Bx{2},Axt_mul_BxJac{2}}) = ((size(P.A,1)[2],size(P.B,1)[2]),size(P.A,2))

fun_name(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)
fun_name(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)

domainType(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = domainType(L.A)
codomainType(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = codomainType(L.A)

# utils
function permute(P::Axt_mul_Bx{N,L1,L2,C,D},
function permute(P::Axt_mul_Bx{N,L1,L2,C,D},
p::AbstractVector{Int}) where {N,L1,L2,C,D <:ArrayPartition}
Axt_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
end

remove_displacement(P::Axt_mul_Bx) =
remove_displacement(P::Axt_mul_Bx) =
Axt_mul_Bx(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)
19 changes: 9 additions & 10 deletions src/calculus/BroadCast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ julia> B*[1.;2.]
```
"""
struct BroadCast{N,
L <: AbstractOperator,
T <: AbstractArray,
struct BroadCast{N,
L <: AbstractOperator,
T <: AbstractArray,
D <: AbstractArray,
M,
C <: NTuple{M,Colon},
Expand All @@ -36,14 +36,14 @@ struct BroadCast{N,
cols::C
idxs::I

function BroadCast(A::L,dim_out::NTuple{N,Int},bufC::T, bufD::D) where {N,
L<:AbstractOperator,
function BroadCast(A::L,dim_out::NTuple{N,Int},bufC::T, bufD::D) where {N,
L<:AbstractOperator,
T<:AbstractArray,
D<:AbstractArray
}
Base.Broadcast.check_broadcast_shape(dim_out,size(A,1))
if size(A,1) != (1,)
M = length(size(A,1))
M = length(size(A,1))
cols = ([Colon() for i = 1:M]...,)
idxs = CartesianIndices((dim_out[M+1:end]...,))
new{N,L,T,D,M,typeof(cols),typeof(idxs)}(A,dim_out,bufC,bufD,cols,idxs)
Expand All @@ -52,14 +52,13 @@ struct BroadCast{N,
idxs = CartesianIndices((1,))
new{N,L,T,D,M,NTuple{0,Colon},typeof(idxs)}(A,dim_out,bufC,bufD,(),idxs)
end

end
end

# Constructors

BroadCast(A::L, dim_out::NTuple{N,Int}) where {N,L<:AbstractOperator} =
BroadCast(A, dim_out, zeros(codomainType(A),size(A,1)), zeros(domainType(A),size(A,2)) )
BroadCast(A, dim_out, allocateInCodomain(A), allocateInDomain(A))

# Mappings

Expand All @@ -82,7 +81,7 @@ end
function mul!(y::CC, A::AdjointOperator{BroadCast{N,L,T,D,0,C,I}}, b::DD) where {N,L,T,D,C,I,CC,DD}
R = A.A
fill!(y, 0.)
bii = zeros(eltype(b),1)
bii = allocateInCodomain(R.A)
for bi in b
bii[1] = bi
mul!(R.bufD, R.A', bii)
Expand All @@ -92,7 +91,7 @@ function mul!(y::CC, A::AdjointOperator{BroadCast{N,L,T,D,0,C,I}}, b::DD) where
end

#TODO make this more general
#length(dim_out) == size(A,1) e.g. a .= b; size(a) = (m,n) size(b) = (1,n) matrix out, column in
#length(dim_out) == size(A,1) e.g. a .= b; size(a) = (m,n) size(b) = (1,n) matrix out, column in
function mul!(y::CC, A::AdjointOperator{BroadCast{2,L,T,D,2,C,I}}, b::DD) where {L,T,D,C,I,CC,DD}
R = A.A
fill!(y, 0.)
Expand Down
17 changes: 10 additions & 7 deletions src/calculus/Compose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ export Compose
"""
`Compose(A::AbstractOperator,B::AbstractOperator)`
Shorthand constructor:
Shorthand constructor:
`A*B`
`A*B`
Compose different `AbstractOperator`s. Notice that the domain and codomain of the operators `A` and `B` must match, i.e. `size(A,2) == size(B,1)` and `domainType(A) == codomainType(B)`.
Expand All @@ -28,19 +28,22 @@ end

function Compose(L1::AbstractOperator, L2::AbstractOperator)
if size(L1,2) != size(L2,1)
throw(DimensionMismatch("cannot compose operators"))
throw(DimensionMismatch("cannot compose operators with different domain and codomain sizes"))
end
if domainType(L1) != codomainType(L2)
throw(DomainError())
throw(DomainError((domainType(L1),codomainType(L2)), "cannot compose operators with different domain and codomain types"))
end
Compose( L1, L2, Array{domainType(L1)}(undef,size(L2,1)) )
if domainStorageType(L1) != codomainStorageType(L2)
throw(DomainError((domainStorageType(L1),codomainStorageType(L2)), "cannot compose operators with different input and output storage types"))
end
Compose(L1, L2, allocateInCodomain(L2))
end

Compose(L1::AbstractOperator,L2::AbstractOperator,buf::AbstractArray) =
Compose( (L2,L1), (buf,))
Compose((L2,L1), (buf,))

Compose(L1::Compose, L2::AbstractOperator,buf::AbstractArray) =
Compose( (L2,L1.A...), (buf,L1.buf...))
Compose((L2,L1.A...), (buf,L1.buf...))

Compose(L1::AbstractOperator,L2::Compose, buf::AbstractArray) =
Compose((L2.A...,L1), (L2.buf...,buf))
Expand Down
Loading

0 comments on commit 7413876

Please sign in to comment.