diff --git a/src/Counters.jl b/src/Counters.jl index 1ce4cc4..1711cf7 100644 --- a/src/Counters.jl +++ b/src/Counters.jl @@ -1,15 +1,12 @@ using Tensors: Tensor -using Memoize flops(::Tensor) = 0 -@memoize function flops(expr::EinExpr) - flops_sub = sum(flops.(expr.args)) - - floppi(inds) = mapreduce(i -> size(expr, i), *, inds, init = one(BigInt)) - flops_cur = floppi(suminds(expr)) * (isempty(suminds(expr)) && length(expr.args) == 1 ? 0 : floppi(labels(expr))) - - return flops_sub + flops_cur -end +flops(expr::EinExpr) = + if isempty(suminds(expr)) && length(expr.args) == 1 + 0 + else + mapreduce(i -> size(expr, i), *, [labels(expr)..., suminds(expr)...]) + end removedsize(::Tensor) = 0 removedsize(expr::EinExpr) = mapreduce(prod ∘ size, +, expr.args) - prod(size(expr)) diff --git a/src/EinExpr.jl b/src/EinExpr.jl index acaef24..48c7f42 100644 --- a/src/EinExpr.jl +++ b/src/EinExpr.jl @@ -40,12 +40,7 @@ Return the size of the `Tensor` resulting from contracting `expr`. If `index` is """ Base.size(expr::EinExpr) = tuple((size(expr, i) for i in labels(expr))...) -function Base.size(expr::EinExpr, i::Symbol) - target = findfirst(input -> i ∈ labels(input), expr.args) - isnothing(target) && throw(KeyError(i)) - - return size(expr.args[target], i) -end +Base.size(expr::EinExpr, i::Symbol) = Iterators.filter(∋(i) ∘ labels, expr) |> first |> x -> size(x, i) """ select(expr, i) diff --git a/src/EinExprs.jl b/src/EinExprs.jl index 3431a85..3d1676a 100644 --- a/src/EinExprs.jl +++ b/src/EinExprs.jl @@ -7,6 +7,9 @@ export suminds, path, select include("Counters.jl") export flops, removedsize +include("Slicing.jl") +export findslices, FlopsScorer, SizeScorer + include("Optimizers/Optimizers.jl") export Optimizer, einexpr export Exhaustive, Greedy diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index 4c697d1..00acb80 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -17,7 +17,7 @@ Exhaustive contraction path optimizers. It guarantees to find the optimal contra The algorithm has a ``\mathcal{O}(n!)`` time complexity if `outer = true` and ``\mathcal{O}(\exp(n))`` if `outer = false`. """ @kwdef struct Exhaustive <: Optimizer - metric::Function = flops + metric::Function = path -> mapreduce(flops, +, path) outer::Bool = false end diff --git a/src/Slicing.jl b/src/Slicing.jl new file mode 100644 index 0000000..4cd6dcc --- /dev/null +++ b/src/Slicing.jl @@ -0,0 +1,85 @@ +Base.selectdim(path::EinExpr, index::Symbol, i) = EinExpr(map(path.args) do sub + index ∈ __labels_children(sub) ? selectdim(sub, index, i) : sub +end, filter(!=(index), path.head)) + +__labels_children(x) = labels(x) +__labels_children(path::EinExpr) = labels(path, all = true) + +Base.view(path::EinExpr, cuttings::Pair{Symbol,<:Integer}...) = + reduce(cuttings, init = path) do acc, proj + d, i = proj + selectdim(acc, d, i) + end + +function findslices( + scorer, + path::EinExpr; + size = nothing, + overhead = nothing, + slices = nothing, + temperature = 0.01, + skip = labels(path), +) + all(isnothing, (size, overhead, slices)) && + throw(ArgumentError("need to specify at least one size, overhead or slices target")) + + candidates = Set(setdiff(mapreduce(labels, ∪, path), skip)) + solution = Set{Symbol}() + current = (; slices = 1, size = maximum(prod ∘ Base.size, path), overhead = 1.0) + original_flops = mapreduce(flops, +, path) + + sliced_path = path + while !isempty(candidates) + # temperature adds boltzmann like noise + winner = argmax(candidates) do index + scorer(sliced_path, index) - temperature * (log ∘ (-) ∘ log ∘ rand)() + end + delete!(candidates, winner) + + sliced_path = selectdim(sliced_path, winner, 1) + cur_overhead = + prod(i -> Base.size(path, i), [solution..., winner]) * mapreduce(flops, +, sliced_path) / original_flops + + !isnothing(overhead) && cur_overhead > overhead && break + push!(solution, winner) + + current = (; + slices = current.slices * (prod ∘ Base.size)(path, winner), + size = maximum(prod ∘ Base.size, sliced_path), + overhead = cur_overhead, + ) + + !isnothing(slices) && current.slices >= slices && break + !isnothing(size) && current.size <= size && break + end + + return solution +end + +abstract type Scorer end + +Base.@kwdef struct FlopsScorer <: Scorer + weight::Float64 = 1e-3 +end + +function (cb::FlopsScorer)(path, index) + slice = selectdim(path, index, 1) + + flops_reduction = mapreduce(flops, +, path) - mapreduce(flops, +, slice) + write_reduction = mapreduce(prod ∘ size, +, path) - mapreduce(prod ∘ size, +, slice) + + log(flops_reduction + write_reduction * cb.weight + 1) +end + +Base.@kwdef struct SizeScorer <: Scorer + weight::Float64 = 1e-3 +end + +function (cb::SizeScorer)(path, index) + slice = selectdim(path, index, 1) + + flops_reduction = mapreduce(flops, +, path) - mapreduce(flops, +, slice) + write_reduction = mapreduce(prod ∘ size, +, path) - mapreduce(prod ∘ size, +, slice) + + log(write_reduction + flops_reduction * cb.weight + 1) +end diff --git a/test/Exhaustive_test.jl b/test/Exhaustive_test.jl index 553cc6a..1979a92 100644 --- a/test/Exhaustive_test.jl +++ b/test/Exhaustive_test.jl @@ -35,7 +35,7 @@ expr = einexpr(Exhaustive, EinExpr(tensors, [:p, :j])) @test expr isa EinExpr # TODO traverse through the tree and check everything is ok - @test flops(expr) == 48753 + @test mapreduce(flops, +, expr) == 48753 # FIXME non-determinist behaviour on order @test issetequal(path(expr), [[:q], [:m], [:f, :i], [:g, :l], [:b], [:o], [:c, :e], [:n, :a, :d, :h], [:k]]) end diff --git a/test/Slicing_test.jl b/test/Slicing_test.jl new file mode 100644 index 0000000..e0f126d --- /dev/null +++ b/test/Slicing_test.jl @@ -0,0 +1,90 @@ +@testset "Slicing" begin + sizes = Dict( + :o => 3, + :b => 7, + :p => 6, + :n => 7, + :j => 9, + :k => 8, + :d => 4, + :e => 2, + :c => 2, + :h => 5, + :i => 5, + :l => 10, + :m => 7, + :q => 5, + :a => 3, + :f => 7, + :g => 3, + ) + + expr = EinExpr( + [ + EinExpr( + [ + EinExpr( + [ + EinExpr( + [ + EinExpr( + [ + EinExpr( + [ + EinExpr( + [ + EinExpr( + [ + Tensor( + ones((sizes[i] for i in [:m, :f, :q])...), + [:m, :f, :q], + ), + Tensor( + ones((sizes[i] for i in [:g, :q])...), + [:g, :q], + ), + ], + [:m, :f, :g], + ), + Tensor( + ones((sizes[i] for i in [:o, :i, :m, :c])...), + [:o, :i, :m, :c], + ), + ], + [:f, :g, :o, :i, :c], + ), + Tensor(ones((sizes[i] for i in [:f, :l, :i])...), [:f, :l, :i]), + ], + [:g, :o, :c, :l], + ), + Tensor(ones((sizes[i] for i in [:g, :n, :l, :a])...), [:g, :n, :l, :a]), + ], + [:o, :c, :n, :a], + ), + EinExpr( + [ + Tensor(ones((sizes[i] for i in [:b, :e])...), [:b, :e]), + Tensor(ones((sizes[i] for i in [:d, :b, :o])...), [:d, :b, :o]), + ], + [:e, :d, :o], + ), + ], + [:c, :n, :a, :e, :d], + ), + Tensor(ones((sizes[i] for i in [:c, :e, :h])...), [:c, :e, :h]), + ], + [:n, :a, :d, :h], + ), + Tensor(ones((sizes[i] for i in [:k, :d, :h, :a, :n, :j])...), [:k, :d, :h, :a, :n, :j]), + ], + [:k, :j], + ), + Tensor(ones((sizes[i] for i in [:p, :k])...), [:p, :k]), + ], + [:p, :j], + ) + + cuttings = findslices(FlopsScorer(), expr, slices = 1000) + + @test prod(i -> size(expr, i), cuttings) >= 1000 +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 7e4063a..88b4088 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,7 @@ using EinExprs @testset "Optimizers" verbose = true begin include("Exhaustive_test.jl") end + include("Slicing_test.jl") end @testset "Integration tests" verbose = true begin