Skip to content

Commit

Permalink
add dense 2-norm belief propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Oct 17, 2023
1 parent 8fb68cc commit d3cdd9e
Show file tree
Hide file tree
Showing 5 changed files with 1,007 additions and 0 deletions.
91 changes: 91 additions & 0 deletions quimb/experimental/belief_propagation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Belief propagation (BP) routines. There are three potential categorizations
of BP and each combination of them is a potentially valid specific algorithm.
1-norm vs 2-norm BP
-------------------
- 1-norm (normal): BP runs directly on the tensor network, messages have size
``d`` where ``d`` is the size of the bond(s) connecting two tensors or
regions.
- 2-norm (quantum): BP runs on the squared tensor network, messages have size
``d^2`` where ``d`` is the size of the bond(s) connecting two tensors or
regions. Each local tensor or region is partially traced (over dangling
indices) with its conjugate to create a single node.
Graph vs Hypergraph BP
----------------------
- Graph (simple): the tensor network lives on a graph, where indices either
appear on two tensors (a bond), or appear on a single tensor (are outputs).
In this case, messages are exchanged directly between tensors.
- Hypergraph: the tensor network lives on a hypergraph, where indices can
appear on any number of tensors. In this case, the update procedure is two
parts, first all 'tensor' messages are computed, these are then used in the
second step to compute all the 'index' messages, which are then fed back into
the 'tensor' message update and so forth. For 2-norm BP one likely needs to
specify which indices are outputs and should be traced over.
The hypergraph case of course includes the graph case, but since the 'index'
message update is simply the identity, it is convenient to have a separate
simpler implementation, where the standard TN bond vs physical index
definitions hold.
Dense vs Vectorized vs Lazy BP
------------------------------
- Dense: each node is a single tensor, or pair of tensors for 2-norm BP. If all
multibonds have been fused, then each message is a vector (1-norm case) or
matrix (2-norm case).
- Vectorized: the same as the above, but all matching tensor update and message
updates are stacked and performed simultaneously. This can be enormously more
efficient for large numbers of small tensors.
- Lazy: each node is potentially a tensor network itself with arbitrary inner
structure and number of bonds connecting to other nodes. The message are
generally tensors and each update is a lazy contraction, which is potentially
much cheaper / requires less memory than forming the 'dense' node for large
tensors.
(There is also the MPS flavor where each node has a 1D structure and the
messages are matrix product states, with updates involving compression.)
Overall that gives 12 possible BP flavors, some implemented here:
- [x] (HD1BP) hyper, dense, 1-norm - this is the standard BP algorithm
- [x] (HD2BP) hyper, dense, 2-norm
- [x] (HV1BP) hyper, vectorized, 1-norm
- [ ] (HV2BP) hyper, vectorized, 2-norm
- [ ] (HL1BP) hyper, lazy, 1-norm
- [ ] (HL2BP) hyper, lazy, 2-norm
- [ ] (D1BP) simple, dense, 1-norm
- [x] (D2BP) simple, dense, 2-norm - this is the standard PEPS BP algorithm
- [ ] (V1BP) simple, vectorized, 1-norm
- [ ] (V2BP) simple, vectorized, 2-norm
- [x] (L1BP) simple, lazy, 1-norm
- [x] (L2BP) simple, lazy, 2-norm
The 2-norm methods can be used to compress bonds or estimate the 2-norm.
The 1-norm methods can be used to estimate the 1-norm, i.e. contracted value.
Both methods can be used to compute index marginals and thus perform sampling.
The vectorized methods can be extremely fast for large numbers of small
tensors, but do currently require all dimensions to match.
The dense and lazy methods can converge messages *locally*, i.e. only
update messages adjacent to messages which have changed.
"""

from .bp_common import initialize_hyper_messages
from .d2bp import D2BP, contract_d2bp, compress_d2bp, sample_d2bp

# NOTE(review): ``HD1BP`` and ``HV1BP`` appear in ``__all__`` but were not
# imported, which would make ``from ... import *`` fail with AttributeError.
# The docstring marks both flavors as implemented - import them from the
# modules named after their acronyms (confirm filenames match).
from .hd1bp import HD1BP
from .hv1bp import HV1BP

__all__ = (
    "initialize_hyper_messages",
    "D2BP",
    "contract_d2bp",
    "compress_d2bp",
    "sample_d2bp",
    "HD1BP",
    "HV1BP",
)
264 changes: 264 additions & 0 deletions quimb/experimental/belief_propagation/bp_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
import functools
import operator

import autoray as ar

import quimb.tensor as qtn


class RollingDiffMean:
    """Track the absolute rolling mean of diffs between successively
    supplied values, used to assess *effective* convergence of BP even when
    the raw message tolerance has not been reached.
    """

    def __init__(self, size=16):
        # window length over which diffs are averaged
        self.size = size
        # FIFO buffer of the most recent diffs
        self.diffs = []
        # previous value seen, ``None`` until the first ``update``
        self.last_x = None
        # running sum of diffs, each pre-divided by the window size
        self.dxsum = 0.0

    def update(self, x):
        """Feed in a new value ``x``, recording its diff from the previous
        value and evicting the oldest diff once the window is full.
        """
        prev = self.last_x
        self.last_x = x
        if prev is None:
            # first value - nothing to diff against yet
            return
        dx = x - prev
        self.diffs.append(dx)
        self.dxsum += dx / self.size
        if len(self.diffs) > self.size:
            # window overflow: drop the oldest diff from the running sum
            self.dxsum -= self.diffs.pop(0) / self.size

    def absmeandiff(self):
        """Absolute rolling mean diff, or ``inf`` until the window fills."""
        if len(self.diffs) < self.size:
            return float("inf")
        return abs(self.dxsum)


class BeliefPropagationCommon:
    """Common interface for belief propagation algorithms.

    Subclasses are expected to provide an ``iterate(tol)`` method that
    performs a single round of message updates and returns the tuple
    ``(nconv, ncheck, max_mdiff)``: the number of messages that have
    converged, the number checked, and the maximum message difference seen.
    """

    def run(self, max_iterations=1000, tol=5e-6, progbar=False):
        """Iterate BP until convergence or ``max_iterations`` is reached,
        setting ``self.converged`` accordingly.

        Parameters
        ----------
        max_iterations : int, optional
            The maximum number of iterations to perform.
        tol : float, optional
            The convergence tolerance for messages.
        progbar : bool, optional
            Whether to show a progress bar.
        """
        if progbar:
            import tqdm

            pbar = tqdm.tqdm()
        else:
            pbar = None

        # defaults in case the loop body never runs (e.g. max_iterations=0),
        # so the warning below cannot hit unbound locals
        nconv = ncheck = 0
        max_mdiff = float("inf")

        try:
            it = 0
            rdm = RollingDiffMean()
            self.converged = False
            while not self.converged and it < max_iterations:
                # can only converge if tol > 0.0
                nconv, ncheck, max_mdiff = self.iterate(tol=tol)
                rdm.update(max_mdiff)
                # converged when messages drop below tol, or when the rolling
                # mean diff flattens out (effective convergence)
                self.converged = (max_mdiff < tol) or (rdm.absmeandiff() < tol)
                it += 1

                if pbar is not None:
                    pbar.set_description(
                        f"nconv: {nconv}/{ncheck} max|dM|={max_mdiff:.2e}",
                        refresh=False,
                    )
                    pbar.update()

        finally:
            # always close the progress bar, even if iterate() raises
            if pbar is not None:
                pbar.close()

        if tol != 0.0 and not self.converged:
            import warnings

            warnings.warn(
                f"Belief propagation did not converge after {max_iterations} "
                f"iterations, tol={tol:.2e}, max|dM|={max_mdiff:.2e}."
            )


def prod(xs):
    """Product of all elements in ``xs``.

    Returns ``1`` for an empty sequence (the empty-product convention),
    rather than raising ``TypeError`` as an initializer-less reduce would.
    """
    return functools.reduce(operator.mul, xs, 1)


def initialize_hyper_messages(
    tn,
    fill_fn=None,
    smudge_factor=1e-12
):
    """Create initial BP messages for hyper tensor network ``tn``, which is
    equivalent to doing a single round of belief propagation starting from
    uniform messages.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to initialize messages for.
    fill_fn : callable, optional
        A function to fill the messages with, of signature ``fill_fn(shape)``.
    smudge_factor : float, optional
        A small number to add to the messages to avoid numerical issues.

    Returns
    -------
    messages : dict
        The initial messages. For every index and tensor id pair, there will
        be a message to and from with keys ``(ix, tid)`` and ``(tid, ix)``.
    """
    from quimb.tensor.contraction import array_contract

    backend = ar.infer_backend(next(t.data for t in tn))
    _sum = ar.get_lib_fn(backend, "sum")

    messages = {}

    # first pass: messages from tensors to indices
    for tid, t in tn.tensor_map.items():
        all_axes = tuple(range(t.ndim))
        for ax, ix in enumerate(t.inds):
            if fill_fn is not None:
                m = fill_fn((t.ind_size(ix),))
            else:
                # sum over every axis of the tensor other than this one
                m = array_contract(
                    arrays=(t.data,),
                    inputs=(all_axes,),
                    output=(ax,),
                )
            # normalize before inserting
            messages[tid, ix] = m / _sum(m)

    # second pass: messages from indices back to tensors
    for ix, tids in tn.ind_map.items():
        ms = [messages[tid, ix] for tid in tids]
        mp = prod(ms)
        for mi, tid in zip(ms, tids):
            # divide out each tensor's own message (the 'excluded' product),
            # with a smudge to avoid division by zero
            m = mp / (mi + smudge_factor)
            # normalize before inserting
            messages[ix, tid] = m / _sum(m)

    return messages


def combine_local_contractions(
    tvals,
    mvals,
    backend,
    strip_exponent=False,
    check_for_zero=True,
):
    """Combine local tensor contraction values ``tvals`` and message
    contraction values ``mvals`` into a single overall scalar, multiplying
    in the former and dividing out the latter. The magnitude is tracked
    separately as a base-10 exponent for numerical stability.

    Parameters
    ----------
    tvals : sequence of scalar
        Local tensor/region contraction values (multiplied in).
    mvals : sequence of scalar
        Message contraction values (divided out).
    backend : str
        The array backend to look up ``abs`` and ``log10`` functions from.
    strip_exponent : bool, optional
        If ``True`` return ``(mantissa, exponent)`` rather than the combined
        value ``mantissa * 10**exponent``.
    check_for_zero : bool, optional
        Whether to short-circuit and return zero if any tensor value is
        exactly zero.
    """
    _abs = ar.get_lib_fn(backend, "abs")
    _log10 = ar.get_lib_fn(backend, "log10")

    mantissa, exponent = 1, 0

    # multiply in each tensor contribution as phase * 10**exponent
    for tv in tvals:
        atv = _abs(tv)
        if check_for_zero and (atv == 0.0):
            # the overall value is exactly zero
            return (0.0, 0.0) if strip_exponent else 0.0
        mantissa = mantissa * (tv / atv)
        exponent = exponent + _log10(atv)

    # divide out each message contribution similarly
    for mv in mvals:
        amv = _abs(mv)
        mantissa = mantissa / (mv / amv)
        exponent = exponent - _log10(amv)

    if strip_exponent:
        return mantissa, exponent
    return mantissa * 10**exponent


def maybe_get_thread_pool(thread_pool):
    """Resolve ``thread_pool`` into an actual pool or ``None``.

    ``False`` means no pool, ``True`` means the default quimb pool, an int
    means a quimb pool of that size, and anything else is assumed to already
    be a pool-like object and passed through unchanged.
    """
    # explicit opt-out
    if thread_pool is False:
        return None

    # note: ``True`` is checked before the int branch would catch it anyway,
    # since bools are ints in Python
    if thread_pool is True or isinstance(thread_pool, int):
        import quimb as qu

        if thread_pool is True:
            return qu.get_thread_pool()
        return qu.get_thread_pool(thread_pool)

    # assume already a pool
    return thread_pool


def create_lazy_community_edge_map(tn, site_tags=None, rank_simplify=True):
    """For lazy BP algorithms, create the data structures describing the
    effective graph of the lazily grouped 'sites' given by ``site_tags``.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to group into sites.
    site_tags : sequence of str, optional
        The tags defining the sites, defaulting to ``tn.site_tags``.
    rank_simplify : bool, optional
        Whether to rank simplify each local tensor network, preparing it
        for cheaper subsequent contractions.

    Returns
    -------
    edges : dict
        Mapping of each sorted site pair ``(i, j)`` to the tuple of bond
        indices connecting their local tensor networks.
    neighbors : dict
        Mapping of each site to the list of its neighboring sites.
    local_tns : dict
        Mapping of each site to its local (copied) tensor network.
    touch_map : dict
        Mapping of each directed message ``(i, j)`` to the tuple of
        downstream messages that need updating when it changes.
    """
    if site_tags is None:
        site_tags = set(tn.site_tags)
    else:
        site_tags = set(site_tags)

    edges = {}
    neighbors = {}
    local_tns = {}
    touch_map = {}

    for ix in tn.ind_map:
        ts = tn._inds_get(ix)
        tags = {tag for t in ts for tag in t.tags if tag in site_tags}
        # NOTE(review): assumes each index connects at most two sites - the
        # unpacking below raises ValueError for a genuine hyperedge between
        # three or more sites; confirm inputs are simple-graph-like
        if len(tags) >= 2:
            i, j = tuple(sorted(tags))

            if (i, j) in edges:
                # already processed this edge
                continue

            # add to neighbor map
            neighbors.setdefault(i, []).append(j)
            neighbors.setdefault(j, []).append(i)

            # get local TNs and compute bonds between them,
            # rank simplify here also to prepare for contractions
            try:
                tn_i = local_tns[i]
            except KeyError:
                tn_i = local_tns[i] = tn.select(i, virtual=False)
                if rank_simplify:
                    tn_i.rank_simplify_()
            try:
                tn_j = local_tns[j]
            except KeyError:
                tn_j = local_tns[j] = tn.select(j, virtual=False)
                if rank_simplify:
                    tn_j.rank_simplify_()

            edges[i, j] = tuple(qtn.bonds(tn_i, tn_j))

    # a message (i, j) changing 'touches' every message (j, k) with k != i
    for i, j in edges:
        touch_map[(i, j)] = tuple((j, k) for k in neighbors[j] if k != i)
        touch_map[(j, i)] = tuple((i, k) for k in neighbors[i] if k != j)

    if len(local_tns) != len(site_tags):
        # handle potentially disconnected sites: fill in only the sites not
        # already selected above, avoiding redundant re-selection and
        # re-rank-simplification of local TNs that were already built
        for i in sorted(site_tags):
            if i in local_tns:
                continue
            try:
                tn_i = local_tns[i] = tn.select(i, virtual=False)
                if rank_simplify:
                    tn_i.rank_simplify_()
            except KeyError:
                # tag selects no tensors at all - skip it
                pass

    return edges, neighbors, local_tns, touch_map
Loading

0 comments on commit d3cdd9e

Please sign in to comment.