Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement slicing functionality #20

Merged
merged 16 commits into from
Jul 22, 2023
15 changes: 6 additions & 9 deletions src/Counters.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
using Tensors: Tensor
using Memoize

# A bare `Tensor` has zero cost.
function flops(::Tensor)
    return 0
end

# Total cost of `expr`: the recursively accumulated cost of its arguments plus the cost
# of the contraction at this node. Results are cached through `@memoize`.
@memoize function flops(expr::EinExpr)
    # product of the sizes of `inds`, accumulated as a `BigInt`
    dimprod(inds) = mapreduce(i -> size(expr, i), *, inds, init = one(BigInt))

    children_cost = sum(flops.(expr.args))

    contracted = suminds(expr)
    # a single argument with nothing summed over is counted as zero cost
    is_noop = isempty(contracted) && length(expr.args) == 1
    own_cost = dimprod(contracted) * (is_noop ? 0 : dimprod(labels(expr)))

    return children_cost + own_cost
end
# Cost of the contraction performed at the root of `expr` only (arguments are not
# recursed into): the product of the sizes of every index involved, i.e. the output
# labels together with the summed indices.
function flops(expr::EinExpr)
    # a single argument with nothing summed over is counted as zero cost
    isempty(suminds(expr)) && length(expr.args) == 1 && return 0

    involved = Iterators.flatten((labels(expr), suminds(expr)))

    # `init = 1` keeps the product defined when there are no indices at all; without
    # it `mapreduce` throws on an empty collection
    return mapreduce(i -> size(expr, i), *, involved, init = 1)
end

# A bare `Tensor` has a removed size of zero.
function removedsize(::Tensor)
    return 0
end

# Removed size of `expr`: the summed element counts of its arguments minus the element
# count of its result.
function removedsize(expr::EinExpr)
    inputs_total = mapreduce(arg -> prod(size(arg)), +, expr.args)
    output_total = prod(size(expr))
    return inputs_total - output_total
end
Expand Down
7 changes: 1 addition & 6 deletions src/EinExpr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,7 @@ Return the size of the `Tensor` resulting from contracting `expr`. If `index` is
"""
Base.size(expr::EinExpr) = tuple((size(expr, i) for i in labels(expr))...)

# Size of index `i` in `expr`, read from the first direct argument whose labels contain
# `i`. Throws `KeyError(i)` when no argument carries that index.
function Base.size(expr::EinExpr, i::Symbol)
    for arg in expr.args
        i ∈ labels(arg) && return size(arg, i)
    end
    throw(KeyError(i))
end
# Size of index `i`, read from the first node yielded by iterating `expr` whose labels
# contain `i`.
#
# Throws `KeyError(i)` when no node carries the index, rather than the generic
# "collection must be non-empty" `ArgumentError` that `first` raises on an empty iterator.
function Base.size(expr::EinExpr, i::Symbol)
    for node in expr
        i ∈ labels(node) && return size(node, i)
    end
    throw(KeyError(i))
end

"""
select(expr, i)
Expand Down
3 changes: 3 additions & 0 deletions src/EinExprs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ export suminds, path, select
include("Counters.jl")
export flops, removedsize

include("Slicing.jl")
export findslices, FlopsScorer, SizeScorer

include("Optimizers/Optimizers.jl")
export Optimizer, einexpr
export Exhaustive, Greedy
Expand Down
2 changes: 1 addition & 1 deletion src/Optimizers/Exhaustive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Exhaustive contraction path optimizers. It guarantees to find the optimal contra
The algorithm has a ``\mathcal{O}(n!)`` time complexity if `outer = true` and ``\mathcal{O}(\exp(n))`` if `outer = false`.
"""
@kwdef struct Exhaustive <: Optimizer
metric::Function = flops
metric::Function = path -> mapreduce(flops, +, path)
outer::Bool = false
end

Expand Down
85 changes: 85 additions & 0 deletions src/Slicing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Slice `path` along `index` at position `i`: every argument whose labels (as reported by
# `__labels_children`) contain `index` is replaced by its own `selectdim` slice, the other
# arguments are kept as they are, and `index` is dropped from the head of the result.
function Base.selectdim(path::EinExpr, index::Symbol, i)
    sliced_args = map(path.args) do child
        if index ∈ __labels_children(child)
            selectdim(child, index, i)
        else
            child
        end
    end
    remaining_head = filter(!=(index), path.head)

    return EinExpr(sliced_args, remaining_head)
end

# Labels used by `selectdim` to decide whether a node is affected by a slice.
# Generic fallback: the node's own labels.
function __labels_children(x)
    return labels(x)
end

# For an `EinExpr`, ask `labels` with `all = true`
# (presumably this also includes labels of nested sub-expressions — TODO confirm).
function __labels_children(path::EinExpr)
    return labels(path; all = true)
end

# Slice `path` along several indices at once: each `index => position` pair in `cuttings`
# is applied in order through `selectdim`. With no pairs, `path` itself is returned.
#
# NOTE: the Codecov patch report flagged this method (src/Slicing.jl lines 8 and 10-11)
# as not covered by tests.
function Base.view(path::EinExpr, cuttings::Pair{Symbol,<:Integer}...)
    # `foldl` rather than `reduce`: the accumulator (an `EinExpr`) and the elements
    # (`Pair`s) have different types, so a guaranteed left-to-right fold is required.
    return foldl(cuttings; init = path) do acc, (index, position)
        selectdim(acc, index, position)
    end
end

"""
    findslices(scorer, path; size, overhead, slices, temperature = 0.01, skip = labels(path))

Greedily select indices of `path` to slice and return them as a `Set{Symbol}`.

Candidate indices are all labels collected while iterating `path`, minus `skip`. On each
round the candidate with the highest `scorer(sliced_path, index)`, perturbed by random
noise scaled by `temperature`, is picked and fixed with `selectdim`. The search stops as
soon as one of the targets is reached or the candidates run out.

# Keywords
- `size`: stop once the largest `prod(size(node))` over the sliced path is `<= size`.
- `overhead`: stop, without accepting the current candidate, once the estimated flops
  overhead exceeds this value.
- `slices`: stop once the product of the sizes of the selected indices is `>= slices`.
- `temperature`: scale of the random noise added to the scores.
- `skip`: indices that are never considered for slicing.

# Throws
- `ArgumentError` if none of `size`, `overhead` and `slices` is given.
"""
function findslices(
scorer,
path::EinExpr;
size = nothing,
overhead = nothing,
slices = nothing,
temperature = 0.01,
skip = labels(path),
)
all(isnothing, (size, overhead, slices)) &&
throw(ArgumentError("need to specify at least one size, overhead or slices target"))

# NOTE: the `size` keyword shadows `Base.size` inside this function, hence the explicit
# `Base.size` qualification below.
candidates = Set(setdiff(mapreduce(labels, ∪, path), skip))
solution = Set{Symbol}()
# running totals for the accepted slicing; `overhead` is recorded but not read again below
current = (; slices = 1, size = maximum(prod ∘ Base.size, path), overhead = 1.0)
original_flops = mapreduce(flops, +, path)

sliced_path = path
while !isempty(candidates)
# temperature adds boltzmann like noise
# i.e. each score becomes score + temperature * (-log(-log(u))) with u = rand()
winner = argmax(candidates) do index
scorer(sliced_path, index) - temperature * (log ∘ (-) ∘ log ∘ rand)()
end
delete!(candidates, winner)

# position 1 is used for the slice (presumably only the resulting sizes matter here — TODO confirm)
sliced_path = selectdim(sliced_path, winner, 1)
# overhead = (product of sizes of the chosen indices, winner included) * (flops of the sliced path) / (flops of the original path)
cur_overhead =
prod(i -> Base.size(path, i), [solution..., winner]) * mapreduce(flops, +, sliced_path) / original_flops

# the winner is rejected and the search ends if it would exceed the overhead target
!isnothing(overhead) && cur_overhead > overhead && break
push!(solution, winner)

current = (;
slices = current.slices * (prod ∘ Base.size)(path, winner),
size = maximum(prod ∘ Base.size, sliced_path),
overhead = cur_overhead,
)

# stop as soon as the slice-count or the size target is satisfied
!isnothing(slices) && current.slices >= slices && break
!isnothing(size) && current.size <= size && break
end

return solution
end

# Supertype of the callable scoring strategies defined in this file (`FlopsScorer`, `SizeScorer`).
abstract type Scorer end

# Scorer driven mainly by the flops removed when slicing an index; the reduction in summed
# node sizes also contributes, scaled by `weight`.
Base.@kwdef struct FlopsScorer <: Scorer
    weight::Float64 = 1e-3
end

# Score of slicing `index` out of `path`:
# `log(flops saved + weight * size saved + 1)`, measured on the slice at position 1.
function (scorer::FlopsScorer)(path, index)
    sliced = selectdim(path, index, 1)

    total_flops(p) = mapreduce(flops, +, p)
    total_size(p) = mapreduce(prod ∘ size, +, p)

    saved_flops = total_flops(path) - total_flops(sliced)
    saved_size = total_size(path) - total_size(sliced)

    return log(saved_flops + saved_size * scorer.weight + 1)
end

# Scorer driven mainly by the reduction in summed node sizes when slicing an index; the
# flops removed also contribute, scaled by `weight`.
Base.@kwdef struct SizeScorer <: Scorer
    weight::Float64 = 1e-3
end

# Score of slicing `index` out of `path`:
# `log(size saved + weight * flops saved + 1)`, measured on the slice at position 1.
#
# NOTE: the Codecov patch report flagged this method (src/Slicing.jl lines 78-79, 81-82
# and 84) as not covered by tests.
function (scorer::SizeScorer)(path, index)
    sliced = selectdim(path, index, 1)

    total_flops(p) = mapreduce(flops, +, p)
    total_size(p) = mapreduce(prod ∘ size, +, p)

    saved_flops = total_flops(path) - total_flops(sliced)
    saved_size = total_size(path) - total_size(sliced)

    return log(saved_size + saved_flops * scorer.weight + 1)
end
2 changes: 1 addition & 1 deletion test/Exhaustive_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
expr = einexpr(Exhaustive, EinExpr(tensors, [:p, :j]))
@test expr isa EinExpr
# TODO traverse through the tree and check everything is ok
@test flops(expr) == 48753
@test mapreduce(flops, +, expr) == 48753
# FIXME non-deterministic behaviour on order
@test issetequal(path(expr), [[:q], [:m], [:f, :i], [:g, :l], [:b], [:o], [:c, :e], [:n, :a, :d, :h], [:k]])
end
90 changes: 90 additions & 0 deletions test/Slicing_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
@testset "Slicing" begin
    # dimension of every index used in the network below
    sizes = Dict(
        :o => 3, :b => 7, :p => 6, :n => 7, :j => 9, :k => 8,
        :d => 4, :e => 2, :c => 2, :h => 5, :i => 5, :l => 10,
        :m => 7, :q => 5, :a => 3, :f => 7, :g => 3,
    )

    # all-ones tensor labelled by `inds`, with dimensions looked up in `sizes`
    ones_tensor(inds) = Tensor(ones((sizes[i] for i in inds)...), inds)

    # contraction tree, built bottom-up: each stage contracts the previous result
    # with one more operand
    stage1 = EinExpr([ones_tensor([:m, :f, :q]), ones_tensor([:g, :q])], [:m, :f, :g])
    stage2 = EinExpr([stage1, ones_tensor([:o, :i, :m, :c])], [:f, :g, :o, :i, :c])
    stage3 = EinExpr([stage2, ones_tensor([:f, :l, :i])], [:g, :o, :c, :l])
    stage4 = EinExpr([stage3, ones_tensor([:g, :n, :l, :a])], [:o, :c, :n, :a])
    branch = EinExpr([ones_tensor([:b, :e]), ones_tensor([:d, :b, :o])], [:e, :d, :o])
    stage5 = EinExpr([stage4, branch], [:c, :n, :a, :e, :d])
    stage6 = EinExpr([stage5, ones_tensor([:c, :e, :h])], [:n, :a, :d, :h])
    stage7 = EinExpr([stage6, ones_tensor([:k, :d, :h, :a, :n, :j])], [:k, :j])
    expr = EinExpr([stage7, ones_tensor([:p, :k])], [:p, :j])

    cuttings = findslices(FlopsScorer(), expr, slices = 1000)

    # the selected indices must multiply out to at least the requested number of slices
    @test prod(i -> size(expr, i), cuttings) >= 1000
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using EinExprs
@testset "Optimizers" verbose = true begin
include("Exhaustive_test.jl")
end
include("Slicing_test.jl")
end

@testset "Integration tests" verbose = true begin
Expand Down
Loading