Skip to content

Commit cfbfc40

Browse files
authored
Support BroadcastLayout multiplication (#68)
* support lazy broadcasting with mul * Update quasibroadcast.jl * avoid broadcasted * Remove QuasiIndices, unused code * simplifiable for D^2, use in it broadcast * Add reducedim (sum, etc.) * Update QuasiArrays.jl * move memorylayout _sum overloading * Update quasireducedim.jl * increase coverage * Require Julia v1.6 * Increase coverage * Test (x .* D) * y * (x .* D^2) * y
1 parent 6f845d5 commit cfbfc40

File tree

14 files changed

+465
-223
lines changed

14 files changed

+465
-223
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ jobs:
1010
fail-fast: false
1111
matrix:
1212
version:
13-
- '1.5'
14-
- '^1.6.0-0'
13+
- '1.6'
1514
os:
1615
- ubuntu-latest
1716
- macOS-latest

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "QuasiArrays"
22
uuid = "c4ea9172-b204-11e9-377d-29865faadc5c"
33
authors = ["Sheehan Olver <solver@mac.com>"]
4-
version = "0.5.2"
4+
version = "0.6"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -15,9 +15,9 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1515
ArrayLayouts = "0.7"
1616
DomainSets = "0.4"
1717
FillArrays = "0.11"
18-
LazyArrays = "0.21"
18+
LazyArrays = "0.21.5"
1919
StaticArrays = "1"
20-
julia = "1.5"
20+
julia = "1.6"
2121

2222
[extras]
2323
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

src/QuasiArrays.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import Base: Slice, IdentityUnitRange, ScalarIndex, RangeIndex, view, viewindexi
1414
check_parent_index_match, reindex, _isdisjoint, unsafe_indices, _unsafe_ind2sub,
1515
_ind2sub, _sub2ind, _ind2sub_recurse, _lookup, SubArray,
1616
parentindices, reverse, ndims, checkbounds, uncolon,
17-
promote_shape, maybeview, unsafe_view, checkindex, checkbounds_indices,
17+
maybeview, unsafe_view, checkindex, checkbounds_indices,
1818
throw_boundserror, rdims, replace_in_print_matrix, show, summary,
1919
hcat, vcat, hvcat
2020
import Base: *, /, \, +, -, ^, inv
@@ -26,7 +26,10 @@ import Base: exp, log, sqrt,
2626
import Base: Array, Matrix, Vector
2727
import Base: union, intersect, sort, sort!
2828
import Base: conj, real, imag
29-
import Base: sum, cumsum, diff
29+
# reducedim.jl imports
30+
import Base: prod, sum, cumsum, diff, add_sum, mul_prod, mapreduce, max, min, count, _count, any, _any, all, _all, _sum, _prod, _mapreduce, reduced_index, check_reducedims
31+
import Base: BitInteger, IEEEFloat, uniontypes, _InitialValue, safe_tail, reducedim1, _simple_count
32+
3033
import Base: ones, zeros, one, zero, fill
3134

3235
import Base.Broadcast: materialize, materialize!, BroadcastStyle, AbstractArrayStyle, Style, broadcasted, Broadcasted, Unknown,
@@ -45,7 +48,8 @@ import LazyArrays: MemoryLayout, UnknownLayout, Mul, ApplyLayout, BroadcastLayou
4548
rowsupport, colsupport, tuple_type_memorylayouts, applylayout, broadcastlayout,
4649
LdivStyle, most, InvLayout, PInvLayout, sub_materialize, lazymaterialize,
4750
_mul, rowsupport, DiagonalLayout, adjointlayout, transposelayout, conjlayout,
48-
sublayout, call, LazyArrayStyle, layout_getindex, _broadcast2broadcastarray, _applyarray_summary, _broadcastarray_summary
51+
sublayout, call, LazyArrayStyle, layout_getindex, _broadcast2broadcastarray, _applyarray_summary, _broadcastarray_summary,
52+
_broadcasted_mul, simplifiable, simplify
4953

5054
import Base.IteratorsMD
5155

@@ -86,6 +90,7 @@ include("quasireshapedarray.jl")
8690
include("quasipermutedims.jl")
8791
include("quasibroadcast.jl")
8892
include("abstractquasiarraymath.jl")
93+
include("quasireducedim.jl")
8994

9095
include("quasiarray.jl")
9196
include("quasiarraymath.jl")

src/abstractquasiarray.jl

Lines changed: 0 additions & 205 deletions
Original file line numberDiff line numberDiff line change
@@ -489,14 +489,6 @@ their component parts. A typical definition for an array that wraps a parent is
489489
dataids(A::AbstractQuasiArray) = (UInt(objectid(A)),)
490490

491491

492-
## structured matrix methods ##
493-
replace_in_print_matrix(A::AbstractQuasiMatrix,i,j,s::AbstractString) = s
494-
replace_in_print_matrix(A::AbstractQuasiVector,i,j,s::AbstractString) = s
495-
496-
## Concatenation ##
497-
eltypeof(x::AbstractQuasiArray) = eltype(x)
498-
499-
500492
## Reductions and accumulates ##
501493

502494
function isequal(A::AbstractQuasiArray, B::AbstractQuasiArray)
@@ -512,17 +504,6 @@ function isequal(A::AbstractQuasiArray, B::AbstractQuasiArray)
512504
return true
513505
end
514506

515-
function cmp(A::AbstractQuasiVector, B::AbstractQuasiVector)
516-
for (a, b) in zip(A, B)
517-
if !isequal(a, b)
518-
return isless(a, b) ? -1 : 1
519-
end
520-
end
521-
return cmp(length(A), length(B))
522-
end
523-
524-
isless(A::AbstractQuasiVector, B::AbstractQuasiVector) = cmp(A, B) < 0
525-
526507
function (==)(A::AbstractQuasiArray, B::AbstractQuasiArray)
527508
if axes(A) != axes(B)
528509
return false
@@ -539,192 +520,6 @@ function (==)(A::AbstractQuasiArray, B::AbstractQuasiArray)
539520
return anymissing ? missing : true
540521
end
541522

542-
_lookup(ind, r::Inclusion) = ind
543-
544-
_ind2sub(dims::NTuple{N,Any}, ind) where N = (@_inline_meta; _ind2sub_recurse(dims, ind-1))
545-
_ind2sub(inds::QuasiIndices, ind) = (@_inline_meta; _ind2sub_recurse(inds, ind-1))
546-
_ind2sub(inds::Tuple{Inclusion{<:Any},AbstractUnitRange{<:Integer}}, ind) = (@_inline_meta; _ind2sub_recurse(inds, ind-1))
547-
_ind2sub(inds::Tuple{AbstractUnitRange{<:Integer},Inclusion{<:Any}}, ind) = (@_inline_meta; _ind2sub_recurse(inds, ind-1))
548-
549-
function _ind2sub(inds::Union{NTuple{N,Any},QuasiIndices{N}}, ind::AbstractQuasiVector) where N
550-
M = length(ind)
551-
t = ntuple(n->similar(ind),Val(N))
552-
for (i,idx) in pairs(IndexLinear(), ind)
553-
sub = _ind2sub(inds, idx)
554-
for j = 1:N
555-
t[j][i] = sub[j]
556-
end
557-
end
558-
t
559-
end
560-
561-
562-
## map over arrays ##
563-
564-
## transform any set of dimensions
565-
## dims specifies which dimensions will be transformed. for example
566-
## dims==1:2 will call f on all slices A[:,:,...]
567-
568-
function mapslices(f, A::AbstractQuasiArray; dims)
569-
if isempty(dims)
570-
return map(f,A)
571-
end
572-
if !isa(dims, AbstractQuasiVector)
573-
dims = [dims...]
574-
end
575-
576-
dimsA = [axes(A)...]
577-
ndimsA = ndims(A)
578-
alldims = [1:ndimsA;]
579-
580-
otherdims = setdiff(alldims, dims)
581-
582-
idx = Any[first(ind) for ind in axes(A)]
583-
itershape = tuple(dimsA[otherdims]...)
584-
for d in dims
585-
idx[d] = Slice(axes(A, d))
586-
end
587-
588-
# Apply the function to the first slice in order to determine the next steps
589-
Aslice = A[idx...]
590-
r1 = f(Aslice)
591-
# In some cases, we can re-use the first slice for a dramatic performance
592-
# increase. The slice itself must be mutable and the result cannot contain
593-
# any mutable containers. The following errs on the side of being overly
594-
# strict (#18570 & #21123).
595-
safe_for_reuse = isa(Aslice, StridedArray) &&
596-
(isa(r1, Number) || (isa(r1, AbstractQuasiArray) && eltype(r1) <: Number))
597-
598-
# determine result size and allocate
599-
Rsize = copy(dimsA)
600-
# TODO: maybe support removing dimensions
601-
if !isa(r1, AbstractQuasiArray) || ndims(r1) == 0
602-
# If the result of f on a single slice is a scalar then we add singleton
603-
# dimensions. When adding the dimensions, we have to respect the
604-
# index type of the input array (e.g. in the case of OffsetArrays)
605-
tmp = similar(Aslice, typeof(r1), reduced_indices(Aslice, 1:ndims(Aslice)))
606-
tmp[firstindex(tmp)] = r1
607-
r1 = tmp
608-
end
609-
nextra = max(0, length(dims)-ndims(r1))
610-
if eltype(Rsize) == Int
611-
Rsize[dims] = [size(r1)..., ntuple(d->1, nextra)...]
612-
else
613-
Rsize[dims] = [axes(r1)..., ntuple(d->OneTo(1), nextra)...]
614-
end
615-
R = similar(r1, tuple(Rsize...,))
616-
617-
ridx = Any[map(first, axes(R))...]
618-
for d in dims
619-
ridx[d] = axes(R,d)
620-
end
621-
622-
concatenate_setindex!(R, r1, ridx...)
623-
624-
nidx = length(otherdims)
625-
indices = Iterators.drop(CartesianIndices(itershape), 1) # skip the first element, we already handled it
626-
inner_mapslices!(safe_for_reuse, indices, nidx, idx, otherdims, ridx, Aslice, A, f, R)
627-
end
628-
629-
concatenate_setindex!(R, X::AbstractQuasiArray, I...) = (R[I...] = X)
630-
631-
## 1 argument
632-
633-
function map!(f::F, dest::AbstractQuasiArray, A::AbstractQuasiArray) where F
634-
for (i,j) in zip(eachindex(dest),eachindex(A))
635-
dest[i] = f(A[j])
636-
end
637-
return dest
638-
end
639-
640-
# map on collections
641-
map(f, A::AbstractQuasiArray) = collect_similar(A, Generator(f,A))
642-
643-
## 2 argument
644-
function map!(f::F, dest::AbstractQuasiArray, A::AbstractQuasiArray, B::AbstractQuasiArray) where F
645-
for (i, j, k) in zip(eachindex(dest), eachindex(A), eachindex(B))
646-
dest[i] = f(A[j], B[k])
647-
end
648-
return dest
649-
end
650-
651-
652-
function map_n!(f::F, dest::AbstractQuasiArray, As) where F
653-
for i = LinearIndices(As[1])
654-
dest[i] = f(ith_all(i, As)...)
655-
end
656-
return dest
657-
end
658-
659-
map!(f::F, dest::AbstractQuasiArray, As::AbstractQuasiArray...) where {F} = map_n!(f, dest, As)
660-
661-
662-
## hashing AbstractQuasiArray ##
663-
664-
function hash(A::AbstractQuasiArray, h::UInt)
665-
h = hash(AbstractQuasiArray, h)
666-
# Axes are themselves AbstractQuasiArrays, so hashing them directly would stack overflow
667-
# Instead hash the tuple of firsts and lasts along each dimension
668-
h = hash(map(first, axes(A)), h)
669-
h = hash(map(last, axes(A)), h)
670-
isempty(A) && return h
671-
672-
# Goal: Hash approximately log(N) entries with a higher density of hashed elements
673-
# weighted towards the end and special consideration for repeated values. Colliding
674-
# hashes will often subsequently be compared by equality -- and equality between arrays
675-
# works elementwise forwards and is short-circuiting. This means that a collision
676-
# between arrays that differ by elements at the beginning is cheaper than one where the
677-
# difference is towards the end. Furthermore, blindly choosing log(N) entries from a
678-
# sparse array will likely only choose the same element repeatedly (zero in this case).
679-
680-
# To achieve this, we work backwards, starting by hashing the last element of the
681-
# array. After hashing each element, we skip `fibskip` elements, where `fibskip`
682-
# is pulled from the Fibonacci sequence -- Fibonacci was chosen as a simple
683-
# ~O(log(N)) algorithm that ensures we don't hit a common divisor of a dimension
684-
# and only end up hashing one slice of the array (as might happen with powers of
685-
# two). Finally, we find the next distinct value from the one we just hashed.
686-
687-
# This is a little tricky since skipping an integer number of values inherently works
688-
# with linear indices, but `findprev` uses `keys`. Hoist out the conversion "maps":
689-
ks = keys(A)
690-
key_to_linear = LinearIndices(ks) # Index into this map to compute the linear index
691-
linear_to_key = vec(ks) # And vice-versa
692-
693-
# Start at the last index
694-
keyidx = last(ks)
695-
linidx = key_to_linear[keyidx]
696-
fibskip = prevfibskip = oneunit(linidx)
697-
n = 0
698-
while true
699-
n += 1
700-
# Hash the current key-index and its element
701-
elt = A[keyidx]
702-
h = hash(keyidx=>elt, h)
703-
704-
# Skip backwards a Fibonacci number of indices -- this is a linear index operation
705-
linidx = key_to_linear[keyidx]
706-
linidx <= fibskip && break
707-
linidx -= fibskip
708-
keyidx = linear_to_key[linidx]
709-
710-
# Only increase the Fibonacci skip once every N iterations. This was chosen
711-
# to be big enough that all elements of small arrays get hashed while
712-
# obscenely large arrays are still tractable. With a choice of N=4096, an
713-
# entirely-distinct 8000-element array will have ~75% of its elements hashed,
714-
# with every other element hashed in the first half of the array. At the same
715-
# time, hashing a `typemax(Int64)`-length Float64 range takes about a second.
716-
if rem(n, 4096) == 0
717-
fibskip, prevfibskip = fibskip + prevfibskip, fibskip
718-
end
719-
720-
# Find a key index with a value distinct from `elt` -- might be `keyidx` itself
721-
keyidx = findprev(!isequal(elt), A, keyidx)
722-
keyidx === nothing && break
723-
end
724-
725-
return h
726-
end
727-
728523

729524
##
730525
# show

src/indices.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,8 @@ IndexStyle(A::AbstractQuasiArray, B::AbstractQuasiArray...) = IndexStyle(IndexSt
1111
IndexStyle(A::AbstractQuasiArray, B::AbstractArray...) = IndexStyle(IndexStyle(A), IndexStyle(B...))
1212

1313

14-
function promote_shape(a::AbstractQuasiArray, b::AbstractQuasiArray)
15-
promote_shape(axes(a), axes(b))
16-
end
17-
18-
const QuasiIndices{N} = NTuple{N,Union{AbstractQuasiVector{<:Number},AbstractVector{<:Number}}}
19-
function promote_shape(a::QuasiIndices, b::QuasiIndices)
14+
promote_shape(a::AbstractQuasiArray, b::AbstractQuasiArray) = promote_shape(axes(a), axes(b))
15+
function promote_shape(a::Tuple, b::Tuple)
2016
if length(a) < length(b)
2117
return promote_shape(b, a)
2218
end

src/lazyquasiarrays.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,10 @@ _BroadcastQuasiArray(bc::Broadcasted) = BroadcastQuasiArray{combine_eltypes(bc.f
124124
BroadcastQuasiArray(bc::Broadcasted{S}) where S =
125125
_BroadcastQuasiArray(instantiate(Broadcasted{S}(bc.f, _broadcast2broadcastarray(bc.args...))))
126126
BroadcastQuasiArray(b::BroadcastQuasiArray) = b
127-
BroadcastQuasiArray(f, A, As...) = BroadcastQuasiArray(instantiate(broadcasted(f, A, As...)))
127+
BroadcastQuasiArray(f, A, As...) = BroadcastQuasiArray{combine_eltypes(f, (A, As...))}(f, A, As...)
128128
BroadcastQuasiVector(f, A, As...) = BroadcastQuasiVector{combine_eltypes(f, (A, As...))}(f, A, As...)
129129
BroadcastQuasiMatrix(f, A, As...) = BroadcastQuasiMatrix{combine_eltypes(f, (A, As...))}(f, A, As...)
130-
BroadcastQuasiArray{T}(f, A, As...) where {T} = BroadcastQuasiArray{T}(instantiate(broadcasted(f, A, As...)))
130+
BroadcastQuasiArray{T}(f, A, As...) where {T} = BroadcastQuasiArray{T,length(axes(broadcasted(f, A, As...)))}(f, A, As...)
131131
BroadcastQuasiArray{T,N}(f, A, As...) where {T,N} = BroadcastQuasiArray{T,N,typeof(f),typeof((A, As...))}(f, (A, As...))
132132

133133
@inline BroadcastQuasiArray(A::AbstractQuasiArray) = BroadcastQuasiArray(call(A), arguments(A)...)
@@ -201,6 +201,9 @@ _broadcast_mul_arguments(a, B) = __broadcast_mul_arguments(a, _mul_arguments(B).
201201
_mul_arguments(A::BroadcastQuasiMatrix{<:Any,typeof(*),<:Tuple{AbstractQuasiVector,AbstractQuasiMatrix}}) =
202202
_broadcast_mul_arguments(A.args...)
203203

204+
broadcasted(::LazyQuasiArrayStyle{2}, ::typeof(*), a::AbstractQuasiVector, B::ApplyQuasiMatrix{<:Any,typeof(*)}) =
205+
*(_broadcast_mul_arguments(a, B)...)
206+
204207
ndims(M::Applied{LazyQuasiArrayApplyStyle,typeof(*)}) = ndims(last(M.args))
205208

206209
call(a::AbstractQuasiArray) = call(MemoryLayout(typeof(a)), a)
@@ -230,3 +233,10 @@ function *(App::ApplyQuasiMatrix{<:Any,typeof(^),<:Tuple{<:AbstractQuasiMatrix{T
230233
p == 0 && return copy(b)
231234
return A*(ApplyQuasiMatrix(^,A,p-1)*b)
232235
end
236+
237+
function simplifiable(::typeof(*), App::ApplyQuasiMatrix{<:Any,typeof(^),<:Tuple{<:AbstractQuasiMatrix{T},<:Integer}}, b::AbstractQuasiArray) where T
238+
A,p = arguments(App)
239+
simplifiable(*, A, b)
240+
end
241+
242+

src/matmul.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,13 @@ copy(M::Mul{<:Any,QuasiArrayLayout}) = _quasi_mul(M, axes(M))
9999
copy(M::Mul{<:AbstractLazyLayout,QuasiArrayLayout}) = ApplyQuasiArray(M)
100100
copy(M::Mul{ApplyLayout{typeof(\)},QuasiArrayLayout}) = ApplyQuasiArray(M)
101101
copy(M::Mul{QuasiArrayLayout,<:AbstractLazyLayout}) = ApplyQuasiArray(M)
102+
for op in (:+, :-)
103+
@eval begin
104+
copy(M::Mul{BroadcastLayout{typeof($op)},QuasiArrayLayout}) = simplify(M)
105+
copy(M::Mul{QuasiArrayLayout,BroadcastLayout{typeof($op)}}) = simplify(M)
106+
end
107+
end
108+
@inline copy(M::Mul{BroadcastLayout{typeof(*)},QuasiArrayLayout}) = copy(Mul{BroadcastLayout{typeof(*)},UnknownLayout}(M.A,M.B))
102109

103110

104111
LazyArrays._vec_mul_view(a::AbstractQuasiVector, kr, ::Colon) = view(a, kr)

src/quasiarray.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ QuasiVector(par::AbstractVector{T}, axes::Tuple{AbstractVector}) where T =
3333

3434
QuasiArray{T}(par::AbstractArray{<:Any,N}, axes::NTuple{N,AbstractVector}) where {T,N} =
3535
QuasiArray{T,N,typeof(axes)}(par, axes)
36+
QuasiArray{T}(par::AbstractArray{T,N}, axes::NTuple{N,AbstractVector}) where {T,N} =
37+
QuasiArray{T,N,typeof(axes)}(par, axes)
3638
QuasiMatrix{T}(par::AbstractMatrix, axes::NTuple{2,AbstractVector}) where T =
3739
QuasiMatrix{T,typeof(axes)}(par, axes)
3840
QuasiVector{T}(par::AbstractVector, axes::Tuple{AbstractVector}) where T =
@@ -81,7 +83,7 @@ QuasiVector(par::AbstractVector{T}, axes::AbstractArray) where {T} =
8183
QuasiVector(par, (axes,))
8284

8385

84-
QuasiArray(a::AbstractQuasiArray) = QuasiArray(a[map(collect,axes(a))...], axes(a))
86+
QuasiArray(a::AbstractQuasiArray{T}) where T = QuasiArray{T}(a[map(collect,axes(a))...], axes(a))
8587
QuasiArray{T}(a::AbstractQuasiArray) where T = QuasiArray(convert(AbstractArray{T},a[map(collect,axes(a))...]), axes(a))
8688
QuasiArray{T,N}(a::AbstractQuasiArray{<:Any,N}) where {T,N} = QuasiArray(convert(AbstractArray{T,N},a[map(collect,axes(a))...]), axes(a))
8789
QuasiArray{T,N,AXES}(a::AbstractQuasiArray{<:Any,N}) where {T,N,AXES} = QuasiArray{T,N,AXES}(Array{T}(a), axes(a))
@@ -134,3 +136,9 @@ end
134136
axes(A) == axes(B) && A.parent == B
135137
==(B::AbstractArray{V,N}, A::QuasiArray{T,N,NTuple{N,OneTo{Int}}}) where {T,V,N} =
136138
A == B
139+
140+
141+
function reshape(A::QuasiVector, ax::Tuple{Any,OneTo{Int}})
142+
@assert ax == (axes(A,1),Base.OneTo(1))
143+
QuasiMatrix(reshape(A.parent,size(A.parent,1),1), (A.axes[1], Base.OneTo(1)))
144+
end

0 commit comments

Comments
 (0)