From 0936be608348739caf74325d5a8a79d9cba8aef1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
Date: Wed, 17 Jan 2024 21:08:05 +0100
Subject: [PATCH] Implement breadth-first exhaustive search

---
 Project.toml                 |   1 +
 src/Optimizers/Exhaustive.jl | 142 +++++++++++++++++++++++++++++++++--
 2 files changed, 136 insertions(+), 7 deletions(-)

diff --git a/Project.toml b/Project.toml
index c029b56..84aab57 100644
--- a/Project.toml
+++ b/Project.toml
@@ -10,6 +10,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
diff --git a/src/Optimizers/Exhaustive.jl b/src/Optimizers/Exhaustive.jl
index fafe2b9..424ada2 100644
--- a/src/Optimizers/Exhaustive.jl
+++ b/src/Optimizers/Exhaustive.jl
@@ -1,5 +1,6 @@
 using Base: @kwdef
 using Combinatorics
+using LinearAlgebra: Symmetric
 
 @doc raw"""
     Exhaustive(; outer = false)
@@ -20,16 +21,25 @@ The algorithm has a ``\mathcal{O}(n!)`` time complexity if `outer = true` and ``
 @kwdef struct Exhaustive <: Optimizer
     metric::Function = flops
     outer::Bool = false
+    strategy::Symbol = :breadth
 end
 
 function einexpr(config::Exhaustive, path::SizedEinExpr{L}; cost = BigInt(0)) where {L}
-    init_path = einexpr(Naive(), path)
-    leader = Ref((;
-        path = init_path,
-        cost = mapreduce(config.metric, +, Branches(init_path, inverse = true), init = BigInt(0))::BigInt,
-    ))
-    exhaustive_depthfirst(Val(config.metric), path, cost, config.outer, leader)
-    return leader[].path
+    if config.strategy === :breadth
+        # NOTE `cost` is only consumed by the depth-first search
+        return exhaustive_breadthfirst(Val(config.metric), path; outer = config.outer)
+    elseif config.strategy === :depth
+        # seed the leader with the naive path and its cost
+        init_path = einexpr(Naive(), path)
+        leader = Ref((;
+            path = init_path,
+            cost = mapreduce(config.metric, +, Branches(init_path, inverse = true), init = BigInt(0))::BigInt,
+        ))
+        exhaustive_depthfirst(Val(config.metric), path, cost, config.outer, leader)
+        return leader[].path
+    else
+        error("Unknown strategy: $(config.strategy)")
+    end
 end
 
 function exhaustive_depthfirst(
@@ -60,3 +70,121 @@ function exhaustive_depthfirst(
         exhaustive_depthfirst(metric, new_path, new_cost, outer, leader; cache, hashyperinds)
     end
 end
+
+"""
+    exhaustive_breadthfirst(metric::Val, expr::SizedEinExpr; outer = false, hashyperinds)
+
+Search a contraction path for `expr` by growing sets of contracted input tensors one
+tensor count at a time, remembering the cheapest known way (as measured by the metric
+wrapped in `metric`) of building each set. The cost of the path found by `Greedy` is
+used as upper bound for pruning candidates.
+
+Outer products and hyperindices are not supported yet: an error is thrown if `outer` or
+`hashyperinds` is `true`. An error is also thrown if no path is found within the bound.
+"""
+function exhaustive_breadthfirst(
+    @specialize(metric::Val{Metric}),
+    expr::SizedEinExpr{L};
+    outer::Bool = false,
+    hashyperinds = !isempty(hyperinds(expr)),
+) where {L,Metric}
+    outer && error("Outer products not supported yet")
+    hashyperinds && error("Hyperindices not supported yet")
+
+    cost_fac = maximum(values(expr.size))
+
+    # make an initial guess using a fast optimizer like Greedy
+    greedy_path = einexpr(Greedy(), expr)
+    cost_max = mapreduce(Metric, +, Branches(greedy_path, inverse = true), init = BigInt(0))::BigInt
+
+    # number of input tensors
+    n = nargs(expr)
+
+    # S[c]: set of all objects made up by contracting together `c` unique tensors from S[1]
+    # NOTE BitSet contains identifiers (i.e. an `Integer`) of input tensors, so each set is a candidate "contracted" subgraph
+    # NOTE it doesn't contain all combinations (as it's combinatorially big); it's filtered by `cost_max`
+    S = map(_ -> BitSet[], 1:n)
+
+    # initialize S₁
+    S[1] = [sizehint!(BitSet([i]), n) for i in 1:n]
+
+    # caches the best-known cost for constructing each object in S[c]
+    # NOTE no cost because no contraction on S₁ (only input tensors)
+    costs = Dict{BitSet,BigInt}(s => zero(BigInt) for s in S[1])
+
+    # contains the indices of the intermediate tensors in S
+    indices = Dict{BitSet,Vector{L}}(s => head(expr.args[only(s)]) for s in S[1])
+
+    # contains the best-known contraction tree for constructing each object in S[c]
+    # NOTE input tensors get a pair of empty sets, which marks them as leaves
+    trees = Dict{BitSet,Tuple{BitSet,BitSet}}(s => (BitSet(), BitSet()) for s in S[1])
+
+    cost_cur = cost_max
+    cost_prev = zero(cost_max)
+
+    while cost_cur <= cost_max
+        cost_next = cost_max
+
+        # construct all subsets of `c` tensors (S[c]) that fulfill cost <= cost_cur
+        for c in 2:n, k in 1:c÷2, (ia, ta) in enumerate(S[k]), (ib, tb) in enumerate(S[c-k])
+            # special case for k = c/2 ∈ ℕ (i.e. k == c-k): `S[k] === S[c-k]` and thus, we only need `combinations(S[k], 2)`
+            k == c - k && ia >= ib && continue
+
+            # if not disjoint, then ta and tb contain at least one common tensor
+            isdisjoint(ta, tb) || continue
+
+            # new candidate contraction
+            tc = ta ∪ tb # aka Q in the paper
+
+            # best-known cost of `tc`, or the current budget if `tc` has not been built yet
+            cost_known = get(costs, tc, cost_cur)
+
+            # skip candidates that were already settled within the previous budget
+            cost_known > cost_prev || continue
+
+            # compute cost of getting `tc` by contracting `ta` and `tb`
+            shallow_expr_a = EinExpr(indices[ta])
+            shallow_expr_b = EinExpr(indices[tb])
+            expr_c = sum(shallow_expr_a, shallow_expr_b; skip = expr.head)
+
+            μ = costs[ta] + costs[tb] + Metric(SizedEinExpr(expr_c, expr.size))
+
+            # if `μ` is the cheapest known cost for constructing `tc`, record it
+            if μ <= cost_known
+                # NOTE `costs` gains a key exactly when a set is pushed to `S`, so a lookup
+                # replaces scanning `S[c]`
+                haskey(costs, tc) || push!(S[c], tc)
+                costs[tc] = μ
+                indices[tc] = head(expr_c)
+                trees[tc] = (ta, tb)
+
+            elseif cost_cur < μ < cost_next
+                cost_next = μ
+            end
+        end
+
+        isempty(S[n]) || break
+
+        # once the budget has reached `cost_max`, another sweep cannot discover anything
+        # new: fail instead of looping forever
+        cost_cur >= cost_max && error("No contraction path found within the cost bound")
+
+        cost_prev = cost_cur
+        cost_cur = min(cost_max, cost_next * cost_fac)
+    end
+
+    # rebuild the contraction tree of `tc` from the recorded pairs
+    function recurse_construct(tc)
+        ta, tb = trees[tc]
+
+        # input tensors are leaves
+        if isempty(ta) && isempty(tb)
+            return EinExpr(indices[tc]::Vector{L})
+        end
+
+        return EinExpr(indices[tc], map(recurse_construct, [ta, tb]))
+    end
+
+    path = recurse_construct(only(S[n]))
+    return SizedEinExpr(path, expr.size)
+end