Skip to content

Commit

Permalink
Implement breadth-first exhaustive search
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jan 17, 2024
1 parent 2695edd commit 0936be6
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 7 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Expand Down
116 changes: 109 additions & 7 deletions src/Optimizers/Exhaustive.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Base: @kwdef
using Combinatorics
using LinearAlgebra: Symmetric

@doc raw"""
Exhaustive(; outer = false)
Expand All @@ -20,16 +21,23 @@ The algorithm has a ``\mathcal{O}(n!)`` time complexity if `outer = true` and ``
@kwdef struct Exhaustive <: Optimizer
metric::Function = flops
outer::Bool = false
strategy::Symbol = :breadth
end

function einexpr(config::Exhaustive, path::SizedEinExpr{L}; cost = BigInt(0)) where {L}
init_path = einexpr(Naive(), path)
leader = Ref((;
path = init_path,
cost = mapreduce(config.metric, +, Branches(init_path, inverse = true), init = BigInt(0))::BigInt,
))
exhaustive_depthfirst(Val(config.metric), path, cost, config.outer, leader)
return leader[].path
if config.strategy === :breadth
return exhaustive_breadthfirst(Val(config.metric), path; outer = config.outer)
elseif config.strategy === :depth
init_path = einexpr(Naive(), path)
leader = Ref((;
path = init_path,
cost = mapreduce(config.metric, +, Branches(init_path, inverse = true), init = BigInt(0))::BigInt,
))
exhaustive_depthfirst(Val(config.metric), path, cost, config.outer, leader)
return leader[].path
else
error("Unknown strategy: $(config.strategy)")
end
end

function exhaustive_depthfirst(
Expand Down Expand Up @@ -60,3 +68,97 @@ function exhaustive_depthfirst(
exhaustive_depthfirst(metric, new_path, new_cost, outer, leader; cache, hashyperinds)
end
end

function exhaustive_breadthfirst(
@specialize(metric::Val{Metric}),
expr::SizedEinExpr{L};
outer::Bool = false,
hashyperinds = !isempty(hyperinds(expr)),
) where {L,Metric}
outer && error("Outer products not supported yet")
hashyperinds && error("Hyperindices not supported yet")

cost_fac = maximum(values(expr.size))

# make a initial guess using a fast optimizer like Greedy
greedy_path = einexpr(Greedy(), expr)
cost_max = mapreduce(Metric, +, Branches(greedy_path, inverse = true), init = BigInt(0))::BigInt

# number of input tensors
n = nargs(expr)

# S[c]: set of all objects made up by contracting together `c` unique tensors from S[1]
# NOTE BitSet contains identifiers (i.e. an `Integer`) of input tensors, so each set is a candidate "contracted" subgraph
# NOTE it doesn't contain all combinations (as it's combinatorially big); it's filtered by `cost_max`
S = map(_ -> BitSet[], 1:n)

# initialize S₁
S[1] = [sizehint!(BitSet([i]), n) for i in 1:n]

# caches the best-known cost for constructing each object in S[c]
# NOTE no cost because no contraction on S₁ (only input tensors)
costs = Dict{BitSet,BigInt}(s => zero(BigInt) for s in S[1])

# contains the indices of the intermediate tensors in S
indices = Dict{BitSet,Vector{L}}(s => head(expr.args[only(s)]) for s in S[1])

# contains the best-known contraction tree for constructing each object in S[c]
trees = Dict{BitSet,Tuple{BitSet,BitSet}}(s => (BitSet(), BitSet()) for s in S[1])

cost_cur = cost_max
cost_prev = zero(cost_max)

while cost_cur <= cost_max
cost_next = cost_max

# construct all subsets of `c` tensors (S[c]) that fulfill cost <= cost_cur
for c in 2:n, k in 1:c÷2, (ia, ta) in enumerate(S[k]), (ib, tb) in enumerate(S[c-k])
# special case for k = c/2 ∈ ℕ (i.e. k == c-k): `S[k] === S[c-k]` and thus, we only need `combinations(S[k], 2)`
k == c - k && ia >= ib && continue

# if not disjoint, then ta and tb contain at least one common tensor
isdisjoint(ta, tb) || continue

get(costs, ta tb, cost_cur) > cost_prev || continue

# new candidate contraction
tc = ta tb # aka Q in the paper

# compute cost of getting `tc` by contracting `ta` and `tb
shallow_expr_a = EinExpr(indices[ta])
shallow_expr_b = EinExpr(indices[tb])
expr_c = sum(shallow_expr_a, shallow_expr_b; skip = expr.head)

μ = costs[ta] + costs[tb] + Metric(SizedEinExpr(expr_c, expr.size))

# if `μ` is the cheapest known cost for constructing `tc`, record it
if μ <= get(costs, tc, cost_cur)
tc S[c] && push!(S[c], tc)
costs[tc] = μ
indices[tc] = head(expr_c)
trees[tc] = (ta, tb)

elseif cost_cur < μ < cost_next
cost_next = μ
end
end

isempty(S[n]) || break

cost_prev = cost_cur
cost_cur = min(cost_max, cost_next * cost_fac)
end

function recurse_construct(tc)
ta, tb = trees[tc]

if isempty(ta) && isempty(tb)
return EinExpr(indices[tc]::Vector{L})
end

return EinExpr(indices[tc], map(recurse_construct, [ta, tb]))
end

path = recurse_construct(only(S[n]))
return SizedEinExpr(path, expr.size)
end

0 comments on commit 0936be6

Please sign in to comment.