From a645158f716162a590452b0c6bb589331e8892ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 20 Jun 2023 14:31:32 +0200 Subject: [PATCH 01/16] Init slicing code --- src/EinExprs.jl | 2 ++ src/Slicing.jl | 12 ++++++++++++ 2 files changed, 14 insertions(+) create mode 100644 src/Slicing.jl diff --git a/src/EinExprs.jl b/src/EinExprs.jl index 3431a85..ba25b1c 100644 --- a/src/EinExprs.jl +++ b/src/EinExprs.jl @@ -7,6 +7,8 @@ export suminds, path, select include("Counters.jl") export flops, removedsize +include("Slicing.jl") + include("Optimizers/Optimizers.jl") export Optimizer, einexpr export Exhaustive, Greedy diff --git a/src/Slicing.jl b/src/Slicing.jl new file mode 100644 index 0000000..403fc7c --- /dev/null +++ b/src/Slicing.jl @@ -0,0 +1,12 @@ +Base.selectdim(path::EinExpr, index::Symbol, i) = EinExpr(map(path.args) do sub + index ∈ __labels_children(sub) ? selectdim(sub, index, i) : sub + end, filter(!=(index), path.head)) + +__labels_children(x) = labels(x) +__labels_children(path::EinExpr) = labels(path, all=true) + +Base.view(path::EinExpr, cuttings::Pair{Symbol,<:Integer}...) = + reduce(cuttings, init=path) do acc, proj + d, i = proj + selectdim(acc, d, i) + end \ No newline at end of file From 2e46e404279641f2a4cd34c5b1ed46d16a313082 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 19 Jul 2023 21:59:59 +0200 Subject: [PATCH 02/16] Prototype the `slices` function --- src/EinExprs.jl | 1 + src/Slicing.jl | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/EinExprs.jl b/src/EinExprs.jl index ba25b1c..be87417 100644 --- a/src/EinExprs.jl +++ b/src/EinExprs.jl @@ -8,6 +8,7 @@ include("Counters.jl") export flops, removedsize include("Slicing.jl") +export slices include("Optimizers/Optimizers.jl") export Optimizer, einexpr diff --git a/src/Slicing.jl b/src/Slicing.jl index 403fc7c..1fcebcd 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -9,4 +9,37 @@ Base.view(path::EinExpr, cuttings::Pair{Symbol,<:Integer}...) = reduce(cuttings, init=path) do acc, proj d, i = proj selectdim(acc, d, i) - end \ No newline at end of file + end + +function slices( + target::Function, + path::EinExpr; + size=nothing, + overhead=nothing, + slices=nothing, + temperature=0.01, + skip=Set{Symbol}() +) + candidates = setdiff(labels(path, all=true), skip) + solution = Set{Symbol}() + + current = (; slices=1, size=..., overhead=1.0) + + checkpredicates() = !isnothing(size) && ... || !isnothing(slices) && ... || !isnothing(overhead) && ... + + while checkpredicates() + winner = maximum(candidates) do index + # score + boltzmann sampling + target(...) - temperature * (log ∘ (-) ∘ log ∘ rand) + end + + push!(winner, solution) + current = (; + slices=current.slices * size(path, winner), + size=..., + overhead=... + ) + end + + return solution +end From 332706a06fa6ef330cbda4ba0efb3a111331cf92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 20 Jul 2023 00:00:15 +0200 Subject: [PATCH 03/16] Implement `slices` --- src/Slicing.jl | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/Slicing.jl b/src/Slicing.jl index 1fcebcd..8fe1974 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -22,23 +22,24 @@ function slices( ) candidates = setdiff(labels(path, all=true), skip) solution = Set{Symbol}() + current = (; slices=1, size=maximum(size, path), overhead=1.0) - current = (; slices=1, size=..., overhead=1.0) - - checkpredicates() = !isnothing(size) && ... || !isnothing(slices) && ... || !isnothing(overhead) && ... - - while checkpredicates() + sliced_path = path + while !(!isnothing(slices) && current.slices >= slices || !isnothing(size) && current.size <= size) + # temperature adds boltzmann like noise winner = maximum(candidates) do index - # score + boltzmann sampling - target(...) - temperature * (log ∘ (-) ∘ log ∘ rand) + target(sliced_path, index) - temperature * (log ∘ (-) ∘ log ∘ rand)() end - push!(winner, solution) + sliced_path = selectdim(sliced_path, winner, 1) current = (; slices=current.slices * size(path, winner), - size=..., - overhead=... + size=maximum(size, sliced_path), + overhead=flops(sliced_path) / flops(path) ) + + !isnothing(overhead) && current.overhead > overhead && break + push!(winner, solution) end return solution From d903249cbb70b35881abe7f171f959329ebff859 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 20 Jul 2023 00:01:56 +0200 Subject: [PATCH 04/16] Cache naively `flops` computation --- src/Slicing.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Slicing.jl b/src/Slicing.jl index 8fe1974..215684e 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -23,6 +23,7 @@ function slices( candidates = setdiff(labels(path, all=true), skip) solution = Set{Symbol}() current = (; slices=1, size=maximum(size, path), overhead=1.0) + original_flops = flops(path) sliced_path = path while !(!isnothing(slices) && current.slices >= slices || !isnothing(size) && current.size <= size) @@ -35,7 +36,7 @@ function slices( current = (; slices=current.slices * size(path, winner), size=maximum(size, sliced_path), - overhead=flops(sliced_path) / flops(path) + overhead=flops(sliced_path) / original_flops ) !isnothing(overhead) && current.overhead > overhead && break From f9713879bae2edc42c3bda8a7f8085d97399e757 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 20 Jul 2023 00:02:52 +0200 Subject: [PATCH 05/16] Format code --- src/Slicing.jl | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/src/Slicing.jl b/src/Slicing.jl index 215684e..caae75b 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -1,12 +1,15 @@ -Base.selectdim(path::EinExpr, index::Symbol, i) = EinExpr(map(path.args) do sub +Base.selectdim(path::EinExpr, index::Symbol, i) = EinExpr( + map(path.args) do sub index ∈ __labels_children(sub) ? selectdim(sub, index, i) : sub - end, filter(!=(index), path.head)) + end, + filter(!=(index), path.head), +) __labels_children(x) = labels(x) -__labels_children(path::EinExpr) = labels(path, all=true) +__labels_children(path::EinExpr) = labels(path, all = true) Base.view(path::EinExpr, cuttings::Pair{Symbol,<:Integer}...) = - reduce(cuttings, init=path) do acc, proj + reduce(cuttings, init = path) do acc, proj d, i = proj selectdim(acc, d, i) end @@ -14,19 +17,22 @@ Base.view(path::EinExpr, cuttings::Pair{Symbol,<:Integer}...) = function slices( target::Function, path::EinExpr; - size=nothing, - overhead=nothing, - slices=nothing, - temperature=0.01, - skip=Set{Symbol}() + size = nothing, + overhead = nothing, + slices = nothing, + temperature = 0.01, + skip = Set{Symbol}(), ) - candidates = setdiff(labels(path, all=true), skip) + candidates = setdiff(labels(path, all = true), skip) solution = Set{Symbol}() - current = (; slices=1, size=maximum(size, path), overhead=1.0) + current = (; slices = 1, size = maximum(size, path), overhead = 1.0) original_flops = flops(path) sliced_path = path - while !(!isnothing(slices) && current.slices >= slices || !isnothing(size) && current.size <= size) + while !( + !isnothing(slices) && current.slices >= slices || + !isnothing(size) && current.size <= size + ) # temperature adds boltzmann like noise winner = maximum(candidates) do index target(sliced_path, index) - temperature * (log ∘ (-) ∘ log ∘ rand)() @@ -34,9 +40,9 @@ function slices( sliced_path = selectdim(sliced_path, winner, 1) current = (; - slices=current.slices * size(path, winner), - size=maximum(size, sliced_path), - overhead=flops(sliced_path) / original_flops + slices = current.slices * size(path, winner), + size = maximum(size, sliced_path), + overhead = flops(sliced_path) / original_flops, ) !isnothing(overhead) && current.overhead > overhead && break From 70cb13c691912dc0750356584e4486562b3646a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 20 Jul 2023 00:16:45 +0200 Subject: [PATCH 06/16] Relax `target` type --- src/Slicing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Slicing.jl b/src/Slicing.jl index caae75b..ccf7c9e 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -15,7 +15,7 @@ Base.view(path::EinExpr, cuttings::Pair{Symbol,<:Integer}...) = end function slices( - target::Function, + target, path::EinExpr; size = nothing, overhead = nothing, From b885a643e6d22014bd374461e42f1502e780c29b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 21 Jul 2023 17:08:29 +0200 Subject: [PATCH 07/16] Refactor `flops` counter --- src/Counters.jl | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/Counters.jl b/src/Counters.jl index 1ce4cc4..fb4b07c 100644 --- a/src/Counters.jl +++ b/src/Counters.jl @@ -1,15 +1,7 @@ using Tensors: Tensor -using Memoize flops(::Tensor) = 0 -@memoize function flops(expr::EinExpr) - flops_sub = sum(flops.(expr.args)) - - floppi(inds) = mapreduce(i -> size(expr, i), *, inds, init = one(BigInt)) - flops_cur = floppi(suminds(expr)) * (isempty(suminds(expr)) && length(expr.args) == 1 ? 0 : floppi(labels(expr))) - - return flops_sub + flops_cur -end +flops(expr::EinExpr) = mapreduce(i -> size(expr, i), *, [labels(expr)..., suminds(expr)...]) removedsize(::Tensor) = 0 removedsize(expr::EinExpr) = mapreduce(prod ∘ size, +, expr.args) - prod(size(expr)) From 73db26740e78e23bcc5e68ef019c102c8ddbfdc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 21 Jul 2023 18:02:04 +0200 Subject: [PATCH 08/16] Format code --- src/Slicing.jl | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/Slicing.jl b/src/Slicing.jl index ccf7c9e..1f25e82 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -1,9 +1,6 @@ -Base.selectdim(path::EinExpr, index::Symbol, i) = EinExpr( - map(path.args) do sub - index ∈ __labels_children(sub) ? selectdim(sub, index, i) : sub - end, - filter(!=(index), path.head), -) +Base.selectdim(path::EinExpr, index::Symbol, i) = EinExpr(map(path.args) do sub + index ∈ __labels_children(sub) ? selectdim(sub, index, i) : sub +end, filter(!=(index), path.head)) __labels_children(x) = labels(x) __labels_children(path::EinExpr) = labels(path, all = true) @@ -15,7 +12,7 @@ Base.view(path::EinExpr, cuttings::Pair{Symbol,<:Integer}...) = end function slices( - target, + scorer, path::EinExpr; size = nothing, overhead = nothing, @@ -29,13 +26,10 @@ function slices( original_flops = flops(path) sliced_path = path - while !( - !isnothing(slices) && current.slices >= slices || - !isnothing(size) && current.size <= size - ) + while !(!isnothing(slices) && current.slices >= slices || !isnothing(size) && current.size <= size) # temperature adds boltzmann like noise winner = maximum(candidates) do index - target(sliced_path, index) - temperature * (log ∘ (-) ∘ log ∘ rand)() + scorer(sliced_path, index) - temperature * (log ∘ (-) ∘ log ∘ rand)() end sliced_path = selectdim(sliced_path, winner, 1) From 41d4ddd542cb73032130be2e9787fd0d32e92b09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 21 Jul 2023 18:02:49 +0200 Subject: [PATCH 09/16] Implement slicing scorer methods --- src/Slicing.jl | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/Slicing.jl b/src/Slicing.jl index 1f25e82..6701cba 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -45,3 +45,31 @@ function slices( return solution end + +abstract type Scorer end + +Base.@kwdef struct FlopsScorer <: Scorer + weight::Float64 = 1e-3 +end + +function (cb::FlopsScorer)(path, index) + slice = selectdim(path, index, 1) + + flops_reduction = mapreduce(flops, +, path) - mapreduce(flops, +, slice) + write_reduction = mapreduce(size, +, path) - mapreduce(size, +, slice) + + log(flops_reduction + write_reduction * cb.weight + 1) +end + +Base.@kwdef struct SizeScorer <: Scorer + weight::Float64 = 1e-3 +end + +function (cb::SizeScorer)(path, index) + slice = selectdim(path, index, 1) + + flops_reduction = mapreduce(flops, +, path) - mapreduce(flops, +, slice) + write_reduction = mapreduce(size, +, path) - mapreduce(size, +, slice) + + log(write_reduction + flops_reduction * cb.weight + 1) +end From bcf297b14274926910df49c0b50c6518dfacfaa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 21 Jul 2023 18:04:07 +0200 Subject: [PATCH 10/16] Rename `slices` to `findslices` --- src/EinExprs.jl | 2 +- src/Slicing.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/EinExprs.jl b/src/EinExprs.jl index be87417..3d1676a 100644 --- a/src/EinExprs.jl +++ b/src/EinExprs.jl @@ -8,7 +8,7 @@ include("Counters.jl") export flops, removedsize include("Slicing.jl") -export slices +export findslices, FlopsScorer, SizeScorer include("Optimizers/Optimizers.jl") export Optimizer, einexpr diff --git a/src/Slicing.jl b/src/Slicing.jl index 6701cba..99826a2 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -11,7 +11,7 @@ Base.view(path::EinExpr, cuttings::Pair{Symbol,<:Integer}...) = selectdim(acc, d, i) end -function slices( +function findslices( scorer, path::EinExpr; size = nothing, From 39e05fc7bedc79fcd1c8d41935ccd00689513724 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 21 Jul 2023 21:15:11 +0200 Subject: [PATCH 11/16] Fix edge case in `flops` --- src/Counters.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/Counters.jl b/src/Counters.jl index fb4b07c..1711cf7 100644 --- a/src/Counters.jl +++ b/src/Counters.jl @@ -1,7 +1,12 @@ using Tensors: Tensor flops(::Tensor) = 0 -flops(expr::EinExpr) = mapreduce(i -> size(expr, i), *, [labels(expr)..., suminds(expr)...]) +flops(expr::EinExpr) = + if isempty(suminds(expr)) && length(expr.args) == 1 + 0 + else + mapreduce(i -> size(expr, i), *, [labels(expr)..., suminds(expr)...]) + end removedsize(::Tensor) = 0 removedsize(expr::EinExpr) = mapreduce(prod ∘ size, +, expr.args) - prod(size(expr)) From c562370cd408e616b98c12dcc064cb63b2be471a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 21 Jul 2023 21:26:00 +0200 Subject: [PATCH 12/16] Fix `metric` computation in `Exhaustive` optimizer --- src/Optimizers/Exhaustive.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl index 4c697d1..00acb80 100644 --- a/src/Optimizers/Exhaustive.jl +++ b/src/Optimizers/Exhaustive.jl @@ -17,7 +17,7 @@ Exhaustive contraction path optimizers. It guarantees to find the optimal contra The algorithm has a ``\mathcal{O}(n!)`` time complexity if `outer = true` and ``\mathcal{O}(\exp(n))`` if `outer = false`. """ @kwdef struct Exhaustive <: Optimizer - metric::Function = flops + metric::Function = path -> mapreduce(flops, +, path) outer::Bool = false end From df38a16eba99345912bf86f2590194926e7c2015 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 21 Jul 2023 21:33:14 +0200 Subject: [PATCH 13/16] Fix `Exhaustive` test --- test/Exhaustive_test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Exhaustive_test.jl b/test/Exhaustive_test.jl index 553cc6a..1979a92 100644 --- a/test/Exhaustive_test.jl +++ b/test/Exhaustive_test.jl @@ -35,7 +35,7 @@ expr = einexpr(Exhaustive, EinExpr(tensors, [:p, :j])) @test expr isa EinExpr # TODO traverse through the tree and check everything is ok - @test flops(expr) == 48753 + @test mapreduce(flops, +, expr) == 48753 # FIXME non-determinist behaviour on order @test issetequal(path(expr), [[:q], [:m], [:f, :i], [:g, :l], [:b], [:o], [:c, :e], [:n, :a, :d, :h], [:k]]) end From 4fc26e18a43de0d001137d5f3c20e35d045db6d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 22 Jul 2023 03:38:16 +0200 Subject: [PATCH 14/16] Rewrite `size` method on index --- src/EinExpr.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/EinExpr.jl b/src/EinExpr.jl index acaef24..48c7f42 100644 --- a/src/EinExpr.jl +++ b/src/EinExpr.jl @@ -40,12 +40,7 @@ Return the size of the `Tensor` resulting from contracting `expr`. If `index` is """ Base.size(expr::EinExpr) = tuple((size(expr, i) for i in labels(expr))...) -function Base.size(expr::EinExpr, i::Symbol) - target = findfirst(input -> i ∈ labels(input), expr.args) - isnothing(target) && throw(KeyError(i)) - - return size(expr.args[target], i) -end +Base.size(expr::EinExpr, i::Symbol) = Iterators.filter(∋(i) ∘ labels, expr) |> first |> x -> size(x, i) """ select(expr, i) From 6de2f79167b9b663287b7b838b51760af9a68722 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 22 Jul 2023 03:38:33 +0200 Subject: [PATCH 15/16] Fix `findslices` --- src/Slicing.jl | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/src/Slicing.jl b/src/Slicing.jl index 99826a2..4cd6dcc 100644 --- a/src/Slicing.jl +++ b/src/Slicing.jl @@ -18,29 +18,39 @@ function findslices( overhead = nothing, slices = nothing, temperature = 0.01, - skip = Set{Symbol}(), + skip = labels(path), ) - candidates = setdiff(labels(path, all = true), skip) + all(isnothing, (size, overhead, slices)) && + throw(ArgumentError("need to specify at least one size, overhead or slices target")) + + candidates = Set(setdiff(mapreduce(labels, ∪, path), skip)) solution = Set{Symbol}() - current = (; slices = 1, size = maximum(size, path), overhead = 1.0) - original_flops = flops(path) + current = (; slices = 1, size = maximum(prod ∘ Base.size, path), overhead = 1.0) + original_flops = mapreduce(flops, +, path) sliced_path = path - while !(!isnothing(slices) && current.slices >= slices || !isnothing(size) && current.size <= size) + while !isempty(candidates) # temperature adds boltzmann like noise - winner = maximum(candidates) do index + winner = argmax(candidates) do index scorer(sliced_path, index) - temperature * (log ∘ (-) ∘ log ∘ rand)() end + delete!(candidates, winner) sliced_path = selectdim(sliced_path, winner, 1) + cur_overhead = + prod(i -> Base.size(path, i), [solution..., winner]) * mapreduce(flops, +, sliced_path) / original_flops + + !isnothing(overhead) && cur_overhead > overhead && break + push!(solution, winner) + current = (; - slices = current.slices * size(path, winner), - size = maximum(size, sliced_path), - overhead = flops(sliced_path) / original_flops, + slices = current.slices * (prod ∘ Base.size)(path, winner), + size = maximum(prod ∘ Base.size, sliced_path), + overhead = cur_overhead, ) - !isnothing(overhead) && current.overhead > overhead && break - push!(winner, solution) + !isnothing(slices) && current.slices >= slices && break + !isnothing(size) && current.size <= size && break end return solution @@ -56,7 +66,7 @@ function (cb::FlopsScorer)(path, index) slice = selectdim(path, index, 1) flops_reduction = mapreduce(flops, +, path) - mapreduce(flops, +, slice) - write_reduction = mapreduce(size, +, path) - mapreduce(size, +, slice) + write_reduction = mapreduce(prod ∘ size, +, path) - mapreduce(prod ∘ size, +, slice) log(flops_reduction + write_reduction * cb.weight + 1) end @@ -69,7 +79,7 @@ function (cb::SizeScorer)(path, index) slice = selectdim(path, index, 1) flops_reduction = mapreduce(flops, +, path) - mapreduce(flops, +, slice) - write_reduction = mapreduce(size, +, path) - mapreduce(size, +, slice) + write_reduction = mapreduce(prod ∘ size, +, path) - mapreduce(prod ∘ size, +, slice) log(write_reduction + flops_reduction * cb.weight + 1) end From 638779281f5dab86f2ba4555ca2811669ee8ecb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 22 Jul 2023 03:42:56 +0200 Subject: [PATCH 16/16] Test slicing --- test/Slicing_test.jl | 90 ++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 91 insertions(+) create mode 100644 test/Slicing_test.jl diff --git a/test/Slicing_test.jl b/test/Slicing_test.jl new file mode 100644 index 0000000..e0f126d --- /dev/null +++ b/test/Slicing_test.jl @@ -0,0 +1,90 @@ +@testset "Slicing" begin + sizes = Dict( + :o => 3, + :b => 7, + :p => 6, + :n => 7, + :j => 9, + :k => 8, + :d => 4, + :e => 2, + :c => 2, + :h => 5, + :i => 5, + :l => 10, + :m => 7, + :q => 5, + :a => 3, + :f => 7, + :g => 3, + ) + + expr = EinExpr( + [ + EinExpr( + [ + EinExpr( + [ + EinExpr( + [ + EinExpr( + [ + EinExpr( + [ + EinExpr( + [ + EinExpr( + [ + Tensor( + ones((sizes[i] for i in [:m, :f, :q])...), + [:m, :f, :q], + ), + Tensor( + ones((sizes[i] for i in [:g, :q])...), + [:g, :q], + ), + ], + [:m, :f, :g], + ), + Tensor( + ones((sizes[i] for i in [:o, :i, :m, :c])...), + [:o, :i, :m, :c], + ), + ], + [:f, :g, :o, :i, :c], + ), + Tensor(ones((sizes[i] for i in [:f, :l, :i])...), [:f, :l, :i]), + ], + [:g, :o, :c, :l], + ), + Tensor(ones((sizes[i] for i in [:g, :n, :l, :a])...), [:g, :n, :l, :a]), + ], + [:o, :c, :n, :a], + ), + EinExpr( + [ + Tensor(ones((sizes[i] for i in [:b, :e])...), [:b, :e]), + Tensor(ones((sizes[i] for i in [:d, :b, :o])...), [:d, :b, :o]), + ], + [:e, :d, :o], + ), + ], + [:c, :n, :a, :e, :d], + ), + Tensor(ones((sizes[i] for i in [:c, :e, :h])...), [:c, :e, :h]), + ], + [:n, :a, :d, :h], + ), + Tensor(ones((sizes[i] for i in [:k, :d, :h, :a, :n, :j])...), [:k, :d, :h, :a, :n, :j]), + ], + [:k, :j], + ), + Tensor(ones((sizes[i] for i in [:p, :k])...), [:p, :k]), + ], + [:p, :j], + ) + + cuttings = findslices(FlopsScorer(), expr, slices = 1000) + + @test prod(i -> size(expr, i), cuttings) >= 1000 +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 7e4063a..88b4088 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,7 @@ using EinExprs @testset "Optimizers" verbose = true begin include("Exhaustive_test.jl") end + include("Slicing_test.jl") end @testset "Integration tests" verbose = true begin