From d3cdd9ead6276bf6de256d80f4027ce452374cce Mon Sep 17 00:00:00 2001
From: Johnnie Gray <johnniemcgray@gmail.com>
Date: Tue, 17 Oct 2023 16:24:40 -0700
Subject: [PATCH] add dense 2-norm belief propagation

---
 .../belief_propagation/__init__.py            |  91 +++
 .../belief_propagation/bp_common.py           | 264 ++++++++
 quimb/experimental/belief_propagation/d2bp.py | 604 ++++++++++++++++++
 .../test_belief_propagation/__init__.py       |   0
 .../test_belief_propagation/test_d2bp.py      |  48 ++
 5 files changed, 1007 insertions(+)
 create mode 100644 quimb/experimental/belief_propagation/__init__.py
 create mode 100644 quimb/experimental/belief_propagation/bp_common.py
 create mode 100644 quimb/experimental/belief_propagation/d2bp.py
 create mode 100644 tests/test_tensor/test_belief_propagation/__init__.py
 create mode 100644 tests/test_tensor/test_belief_propagation/test_d2bp.py

diff --git a/quimb/experimental/belief_propagation/__init__.py b/quimb/experimental/belief_propagation/__init__.py
new file mode 100644
index 00000000..6582fe46
--- /dev/null
+++ b/quimb/experimental/belief_propagation/__init__.py
@@ -0,0 +1,91 @@
+"""Belief propagation (BP) routines. There are three potential categorizations
+of BP and each combination of them is potentially valid specific algorithm.
+
+1-norm vs 2-norm BP
+-------------------
+
+- 1-norm (normal): BP runs directly on the tensor network, messages have size
+  ``d`` where ``d`` is the size of the bond(s) connecting two tensors or
+  regions.
+- 2-norm (quantum): BP runs on the squared tensor network, messages have size
+  ``d^2`` where ``d`` is the size of the bond(s) connecting two tensors or
+  regions. Each local tensor or region is partially traced (over dangling
+  indices) with its conjugate to create a single node.
+
+
+Graph vs Hypergraph BP
+----------------------
+
+- Graph (simple): the tensor network lives on a graph, where indices either
+  appear on two tensors (a bond), or appear on a single tensor (are outputs).
+  In this case, messages are exchanged directly between tensors.
+- Hypergraph: the tensor network lives on a hypergraph, where indices can
+  appear on any number of tensors. In this case, the update procedure has two
+  parts: first all 'tensor' messages are computed, then these are used to
+  compute all the 'index' messages, which are in turn fed back into the
+  'tensor' message update, and so forth. For 2-norm BP one likely needs to
+  specify which indices are outputs and should be traced over.
+
+The hypergraph case of course includes the graph case, but since the 'index'
+message update is simply the identity, it is convenient to have a separate
+simpler implementation, where the standard TN bond vs physical index
+definitions hold.
+
+
+Dense vs Vectorized vs Lazy BP
+------------------------------
+
+- Dense: each node is a single tensor, or pair of tensors for 2-norm BP. If all
+  multibonds have been fused, then each message is a vector (1-norm case) or
+  matrix (2-norm case).
+- Vectorized: the same as the above, but all matching tensor and message
+  updates are stacked and performed simultaneously. This can be enormously
+  more efficient for large numbers of small tensors.
+- Lazy: each node is potentially a tensor network itself with arbitrary inner
+  structure and number of bonds connecting to other nodes. The messages are
+  generally tensors and each update is a lazy contraction, which is potentially
+  much cheaper / requires less memory than forming the 'dense' node for large
+  tensors.
+
+(There is also the MPS flavor where each node has a 1D structure and the
+messages are matrix product states, with updates involving compression.)
+
+
+Overall that gives 12 possible BP flavors, some implemented here:
+
+- [x] (HD1BP) hyper, dense, 1-norm - this is the standard BP algorithm
+- [x] (HD2BP) hyper, dense, 2-norm
+- [x] (HV1BP) hyper, vectorized, 1-norm
+- [ ] (HV2BP) hyper, vectorized, 2-norm
+- [ ] (HL1BP) hyper, lazy, 1-norm
+- [ ] (HL2BP) hyper, lazy, 2-norm
+- [ ] (D1BP) simple, dense, 1-norm
+- [x] (D2BP) simple, dense, 2-norm - this is the standard PEPS BP algorithm
+- [ ] (V1BP) simple, vectorized, 1-norm
+- [ ] (V2BP) simple, vectorized, 2-norm
+- [x] (L1BP) simple, lazy, 1-norm
+- [x] (L2BP) simple, lazy, 2-norm
+
+The 2-norm methods can be used to compress bonds or estimate the 2-norm.
+The 1-norm methods can be used to estimate the 1-norm, i.e. the contracted
+value. Both methods can be used to compute index marginals and thus perform
+sampling.
+
+The vectorized methods can be extremely fast for large numbers of small
+tensors, but do currently require all dimensions to match.
+
+The dense and lazy methods can converge messages *locally*, i.e. only
+update messages adjacent to messages which have changed.
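+
+As a minimal usage sketch (assuming a PEPS-like tensor network with no hyper
+indices), the implemented D2BP flavor can be run via the functional
+interface::
+
+    import quimb.tensor as qtn
+    from quimb.experimental.belief_propagation import contract_d2bp
+
+    # a small random PEPS with bond dimension 3
+    peps = qtn.PEPS.rand(3, 4, bond_dim=3, seed=42)
+    # BP estimate of the squared norm <peps|peps>
+    norm_sq_estimate = contract_d2bp(peps)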
+"""
+
+from .bp_common import initialize_hyper_messages
+from .d2bp import D2BP, contract_d2bp, compress_d2bp, sample_d2bp
+
+__all__ = (
+    "initialize_hyper_messages",
+    "D2BP",
+    "contract_d2bp",
+    "compress_d2bp",
+    "sample_d2bp",
+    "HD1BP",
+    "HV1BP",
+)
diff --git a/quimb/experimental/belief_propagation/bp_common.py b/quimb/experimental/belief_propagation/bp_common.py
new file mode 100644
index 00000000..c8cf9b47
--- /dev/null
+++ b/quimb/experimental/belief_propagation/bp_common.py
@@ -0,0 +1,264 @@
+import functools
+import operator
+
+import autoray as ar
+
+import quimb.tensor as qtn
+
+
+class RollingDiffMean:
+    """Tracker for the absolute rolling mean of diffs between values, to
+    assess effective convergence of BP above actual message tolerance.
+    """
+
+    def __init__(self, size=16):
+        self.size = size
+        self.diffs = []
+        self.last_x = None
+        self.dxsum = 0.0
+
+    def update(self, x):
+        if self.last_x is not None:
+            dx = x - self.last_x
+            self.diffs.append(dx)
+            self.dxsum += dx / self.size
+        if len(self.diffs) > self.size:
+            dx = self.diffs.pop(0)
+            self.dxsum -= dx / self.size
+        self.last_x = x
+
+    def absmeandiff(self):
+        if len(self.diffs) < self.size:
+            return float("inf")
+        return abs(self.dxsum)
+
+
+class BeliefPropagationCommon:
+    """Common interfaces for belief propagation algorithms.
+
+    Parameters
+    ----------
+    max_iterations : int, optional
+        The maximum number of iterations to perform.
+    tol : float, optional
+        The convergence tolerance for messages.
+    progbar : bool, optional
+        Whether to show a progress bar.
+    """
+
+    def run(self, max_iterations=1000, tol=5e-6, progbar=False):
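+        """Run the belief propagation iteration, stopping when either every
+        message changes by less than ``tol``, the rolling mean of the largest
+        message change stabilizes to within ``tol``, or ``max_iterations`` is
+        reached.
+        """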
+        if progbar:
+            import tqdm
+
+            pbar = tqdm.tqdm()
+        else:
+            pbar = None
+
+        try:
+            it = 0
+            rdm = RollingDiffMean()
+            self.converged = False
+            while not self.converged and it < max_iterations:
+                # can only converge if tol > 0.0
+                nconv, ncheck, max_mdiff = self.iterate(tol=tol)
+                rdm.update(max_mdiff)
+                self.converged = (max_mdiff < tol) or (rdm.absmeandiff() < tol)
+                it += 1
+
+                if pbar is not None:
+                    pbar.set_description(
+                        f"nconv: {nconv}/{ncheck} max|dM|={max_mdiff:.2e}",
+                        refresh=False,
+                    )
+                    pbar.update()
+
+        finally:
+            if pbar is not None:
+                pbar.close()
+
+        if tol != 0.0 and not self.converged:
+            import warnings
+
+            warnings.warn(
+                f"Belief propagation did not converge after {max_iterations} "
+                f"iterations, tol={tol:.2e}, max|dM|={max_mdiff:.2e}."
+            )
+
+
+def prod(xs):
+    """Product of all elements in ``xs``."""
+    return functools.reduce(operator.mul, xs)
+
+
+def initialize_hyper_messages(
+    tn,
+    fill_fn=None,
+    smudge_factor=1e-12,
+):
+    """Initialize messages for belief propagation, this is equivalent to doing
+    a single round of belief propagation with uniform messages.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to initialize messages for.
+    fill_fn : callable, optional
+        A function to fill the messages with, of signature ``fill_fn(shape)``.
+    smudge_factor : float, optional
+        A small offset added when dividing out each message, to avoid
+        division by zero.
+
+    Returns
+    -------
+    messages : dict
+        The initial messages. For every index and tensor id pair, there will
+        be a message to and from with keys ``(ix, tid)`` and ``(tid, ix)``.
+    """
+    from quimb.tensor.contraction import array_contract
+
+    backend = ar.infer_backend(next(t.data for t in tn))
+    _sum = ar.get_lib_fn(backend, "sum")
+
+    messages = {}
+
+    # compute first messages from tensors to indices
+    for tid, t in tn.tensor_map.items():
+        k_inputs = tuple(range(t.ndim))
+        for i, ix in enumerate(t.inds):
+            if fill_fn is None:
+                # sum over all other indices to get initial message
+                m = array_contract(
+                    arrays=(t.data,),
+                    inputs=(k_inputs,),
+                    output=(i,),
+                )
+                # normalize and insert
+                messages[tid, ix] = m / _sum(m)
+            else:
+                d = t.ind_size(ix)
+                m = fill_fn((d,))
+                messages[tid, ix] = m / _sum(m)
+
+    # compute first messages from indices to tensors
+    for ix, tids in tn.ind_map.items():
+        ms = [messages[tid, ix] for tid in tids]
+        mp = prod(ms)
+        for mi, tid in zip(ms, tids):
+            m = mp / (mi + smudge_factor)
+            # normalize and insert
+            messages[ix, tid] = m / _sum(m)
+
+    return messages
+
+
+def combine_local_contractions(
+    tvals,
+    mvals,
+    backend,
+    strip_exponent=False,
+    check_for_zero=True,
+):
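+    """Combine a list of local contraction estimates ``tvals`` (one per
+    tensor or region) and message overlap values ``mvals`` (one per bond)
+    into a single estimate of the overall contraction, dividing out the
+    double counted bond contributions. The accumulation is done in
+    mantissa / exponent form for numerical stability.
+    """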
+    _abs = ar.get_lib_fn(backend, "abs")
+    _log10 = ar.get_lib_fn(backend, "log10")
+
+    mantissa = 1
+    exponent = 0
+    for vt in tvals:
+        avt = _abs(vt)
+
+        if check_for_zero and (avt == 0.0):
+            if strip_exponent:
+                return 0.0, 0.0
+            else:
+                return 0.0
+
+        mantissa = mantissa * (vt / avt)
+        exponent = exponent + _log10(avt)
+    for mt in mvals:
+        amt = _abs(mt)
+        mantissa = mantissa / (mt / amt)
+        exponent = exponent - _log10(amt)
+
+    if strip_exponent:
+        return mantissa, exponent
+    else:
+        return mantissa * 10**exponent
+
+
+def maybe_get_thread_pool(thread_pool):
+    """Get a thread pool if requested."""
+    if thread_pool is False:
+        return None
+
+    if thread_pool is True:
+        import quimb as qu
+
+        return qu.get_thread_pool()
+
+    if isinstance(thread_pool, int):
+        import quimb as qu
+
+        return qu.get_thread_pool(thread_pool)
+
+    return thread_pool
+
+
+def create_lazy_community_edge_map(tn, site_tags=None, rank_simplify=True):
+    """For lazy BP algorithms, create the data structures describing the
+    effective graph of the lazily grouped 'sites' given by ``site_tags``.
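+
+    Returns
+    -------
+    edges : dict[tuple[str, str], tuple[str]]
+        Mapping of each pair of neighboring site tags to the bond indices
+        connecting their local tensor networks.
+    neighbors : dict[str, list[str]]
+        Mapping of each site tag to its neighboring site tags.
+    local_tns : dict[str, TensorNetwork]
+        Mapping of each site tag to the local tensor network it selects.
+    touch_map : dict[tuple[str, str], tuple[tuple[str, str]]]
+        Mapping of each directed message key ``(i, j)`` to the message keys
+        that should be updated if that message changes.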
+    """
+    if site_tags is None:
+        site_tags = set(tn.site_tags)
+    else:
+        site_tags = set(site_tags)
+
+    edges = {}
+    neighbors = {}
+    local_tns = {}
+    touch_map = {}
+
+    for ix in tn.ind_map:
+        ts = tn._inds_get(ix)
+        tags = {tag for t in ts for tag in t.tags if tag in site_tags}
+        if len(tags) >= 2:
+            i, j = tuple(sorted(tags))
+
+            if (i, j) in edges:
+                # already processed this edge
+                continue
+
+            # add to neighbor map
+            neighbors.setdefault(i, []).append(j)
+            neighbors.setdefault(j, []).append(i)
+
+            # get local TNs and compute bonds between them,
+            # rank simplify here also to prepare for contractions
+            try:
+                tn_i = local_tns[i]
+            except KeyError:
+                tn_i = local_tns[i] = tn.select(i, virtual=False)
+                if rank_simplify:
+                    tn_i.rank_simplify_()
+            try:
+                tn_j = local_tns[j]
+            except KeyError:
+                tn_j = local_tns[j] = tn.select(j, virtual=False)
+                if rank_simplify:
+                    tn_j.rank_simplify_()
+
+            edges[i, j] = tuple(qtn.bonds(tn_i, tn_j))
+
+    for i, j in edges:
+        touch_map[(i, j)] = tuple((j, k) for k in neighbors[j] if k != i)
+        touch_map[(j, i)] = tuple((i, k) for k in neighbors[i] if k != j)
+
+    if len(local_tns) != len(site_tags):
+        # handle potentially disconnected sites
+        for i in sorted(site_tags):
+            if i in local_tns:
+                # already created as part of an edge
+                continue
+            try:
+                tn_i = local_tns[i] = tn.select(i, virtual=False)
+                if rank_simplify:
+                    tn_i.rank_simplify_()
+            except KeyError:
+                # this site tag matches no tensors
+                pass
+
+    return edges, neighbors, local_tns, touch_map
diff --git a/quimb/experimental/belief_propagation/d2bp.py b/quimb/experimental/belief_propagation/d2bp.py
new file mode 100644
index 00000000..681d1870
--- /dev/null
+++ b/quimb/experimental/belief_propagation/d2bp.py
@@ -0,0 +1,604 @@
+import autoray as ar
+import quimb.tensor as qtn
+
+from .bp_common import (
+    BeliefPropagationCommon,
+    combine_local_contractions,
+)
+
+
+class D2BP(BeliefPropagationCommon):
+    """Dense (as in one tensor per site) 2-norm (as in for wavefunctions and
+    operators) belief propagation. Allows message reuse. This version assumes
+    no hyper indices (i.e. a standard PEPS like tensor network).
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to form the 2-norm of and run BP on.
+    messages : dict[(str, int), array_like], optional
+        The initial messages to use, effectively defaults to all ones if not
+        specified.
+    output_inds : set[str], optional
+        The indices to consider as output (dangling) indices of the tn.
+        Computed automatically if not specified.
+    optimize : str or PathOptimizer, optional
+        The path optimizer to use when contracting the messages.
+    local_convergence : bool, optional
+        Whether to allow messages to locally converge - i.e. if all their
+        input messages have converged then stop updating them.
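+    damping : float, optional
+        The damping factor to apply to message updates, by default 0.0 (no
+        damping).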
+    contract_opts
+        Other options supplied to ``cotengra.array_contract``.
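+
+    Examples
+    --------
+    A minimal usage sketch, assuming a small random PEPS as the target
+    network::
+
+        import quimb.tensor as qtn
+
+        peps = qtn.PEPS.rand(3, 4, bond_dim=3, seed=42)
+        bp = D2BP(peps)
+        bp.run(max_iterations=100, tol=1e-6)
+        norm_sq_estimate = bp.contract()
+
+    The converged messages can then be reused, for example to compress the
+    same network::
+
+        peps_compressed = bp.compress(max_bond=2)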
+    """
+
+    def __init__(
+        self,
+        tn,
+        messages=None,
+        output_inds=None,
+        optimize="auto-hq",
+        local_convergence=True,
+        damping=0.0,
+        **contract_opts,
+    ):
+        from quimb.tensor.contraction import array_contract_expression
+
+        self.tn = tn
+        self.contract_opts = contract_opts
+        self.contract_opts.setdefault("optimize", optimize)
+        self.local_convergence = local_convergence
+        self.damping = damping
+
+        if output_inds is None:
+            self.output_inds = set(self.tn.outer_inds())
+        else:
+            self.output_inds = set(output_inds)
+
+        self.backend = next(t.backend for t in tn)
+        _abs = ar.get_lib_fn(self.backend, "abs")
+        _sum = ar.get_lib_fn(self.backend, "sum")
+
+        def _normalize(x):
+            return x / _sum(x)
+
+        def _distance(x, y):
+            return _sum(_abs(x - y))
+
+        self._normalize = _normalize
+        self._distance = _distance
+
+        if messages is None:
+            self.messages = {}
+        else:
+            self.messages = messages
+
+        # record which messages touch each other, for efficient updates
+        self.touch_map = {}
+        self.touched = set()
+        self.exprs = {}
+
+        # populate any messages
+        for ix, tids in self.tn.ind_map.items():
+            if ix in self.output_inds:
+                continue
+
+            tida, tidb = tids
+            jx = ix + "*"
+            ta, tb = self.tn._tids_get(tida, tidb)
+
+            for tid, t, t_in in ((tida, ta, tb), (tidb, tb, ta)):
+                this_touchmap = []
+                for nx in t.inds:
+                    if nx in self.output_inds or nx == ix:
+                        continue
+                    # where this message will be sent on to
+                    (tidn,) = (n for n in self.tn.ind_map[nx] if n != tid)
+                    this_touchmap.append((nx, tidn))
+                self.touch_map[ix, tid] = this_touchmap
+
+                if (ix, tid) not in self.messages:
+                    tm = (t_in.reindex({ix: jx}).conj_() @ t_in).data
+                    self.messages[ix, tid] = self._normalize(tm)
+
+        # for efficiency setup all the contraction expressions ahead of time
+        for ix, tids in self.tn.ind_map.items():
+            if ix in self.output_inds:
+                continue
+
+            for tida, tidb in (sorted(tids), sorted(tids, reverse=True)):
+                ta = self.tn.tensor_map[tida]
+                kix = ta.inds
+                bix = tuple(
+                    i if i in self.output_inds else i + "*" for i in kix
+                )
+                inputs = [kix, bix]
+                data = [ta.data, ta.data.conj()]
+                shapes = [ta.shape, ta.shape]
+                for i in kix:
+                    if (i != ix) and i not in self.output_inds:
+                        inputs.append((i + "*", i))
+                        data.append((i, tida))
+                        shapes.append(self.messages[i, tida].shape)
+
+                expr = array_contract_expression(
+                    inputs=inputs,
+                    output=(ix + "*", ix),
+                    shapes=shapes,
+                    **self.contract_opts,
+                )
+                self.exprs[ix, tidb] = expr, data
+
+    def update_touched_from_tids(self, *tids):
+        """Specify that the messages for the given ``tids`` have changed."""
+        for tid in tids:
+            t = self.tn.tensor_map[tid]
+            for ix in t.inds:
+                if ix in self.output_inds:
+                    continue
+                (ntid,) = (n for n in self.tn.ind_map[ix] if n != tid)
+                self.touched.add((ix, ntid))
+
+    def update_touched_from_tags(self, tags, which="any"):
+        """Specify that the messages for the messages touching ``tags`` have
+        changed.
+        """
+        tids = self.tn._get_tids_from_tags(tags, which)
+        self.update_touched_from_tids(*tids)
+
+    def update_touched_from_inds(self, inds, which="any"):
+        """Specify that the messages for the messages touching ``inds`` have
+        changed.
+        """
+        tids = self.tn._get_tids_from_inds(inds, which)
+        self.update_touched_from_tids(*tids)
+
+    def iterate(self, tol=5e-6):
+        """Perform a single iteration of dense 2-norm belief propagation."""
+
+        if (not self.local_convergence) or (not self.touched):
+            # assume if asked to iterate that we want to check all messages
+            self.touched.update(self.exprs.keys())
+
+        ncheck = len(self.touched)
+        new_messages = {}
+        while self.touched:
+            key = self.touched.pop()
+            expr, data = self.exprs[key]
+            m = expr(*data[:2], *(self.messages[mkey] for mkey in data[2:]))
+            # enforce hermiticity and normalize
+            m = m + ar.dag(m)
+            m = self._normalize(m)
+
+            if self.damping > 0.0:
+                m = self._normalize(
+                    # new message
+                    (1 - self.damping) * m
+                    +
+                    # old message
+                    self.damping * self.messages[key]
+                )
+
+            new_messages[key] = m
+
+        # process modified messages
+        nconv = 0
+        max_mdiff = -1.0
+        for key, m in new_messages.items():
+            mdiff = float(self._distance(m, self.messages[key]))
+
+            if mdiff > tol:
+                # mark touching messages for update
+                self.touched.update(self.touch_map[key])
+            else:
+                nconv += 1
+
+            max_mdiff = max(max_mdiff, mdiff)
+            self.messages[key] = m
+
+        return nconv, ncheck, max_mdiff
+
+    def compute_marginal(self, ind):
+        """Compute the marginal for the index ``ind``."""
+        (tid,) = self.tn.ind_map[ind]
+        t = self.tn.tensor_map[tid]
+
+        arrays = [t.data, ar.do("conj", t.data)]
+        k_input = []
+        b_input = []
+        m_inputs = []
+        for j, jx in enumerate(t.inds, 1):
+            k_input.append(j)
+
+            if jx == ind:
+                # output index -> take diagonal
+                output = (j,)
+                b_input.append(j)
+            else:
+                try:
+                    # partial trace with message
+                    m = self.messages[jx, tid]
+                    arrays.append(m)
+                    b_input.append(-j)
+                    m_inputs.append((-j, j))
+                except KeyError:
+                    # direct partial trace
+                    b_input.append(j)
+
+        p = qtn.array_contract(
+            arrays,
+            inputs=(tuple(k_input), tuple(b_input), *m_inputs),
+            output=output,
+            **self.contract_opts,
+        )
+        p = ar.do("real", p)
+        return p / ar.do("sum", p)
+
+    def contract(self, strip_exponent=False):
+        """Estimate the total contraction, i.e. the 2-norm.
+
+        Parameters
+        ----------
+        strip_exponent : bool, optional
+            Whether to strip the exponent from the final result. If ``True``
+            then the returned result is ``(mantissa, exponent)``.
+
+        Returns
+        -------
+        scalar or (scalar, float)
+        """
+        tvals = []
+
+        for tid, t in self.tn.tensor_map.items():
+            arrays = [t.data, ar.do("conj", t.data)]
+            k_input = []
+            b_input = []
+            m_inputs = []
+            for i, ix in enumerate(t.inds, 1):
+                k_input.append(i)
+                if ix in self.output_inds:
+                    b_input.append(i)
+                else:
+                    b_input.append(-i)
+                    m_inputs.append((-i, i))
+                    arrays.append(self.messages[ix, tid])
+
+            inputs = (tuple(k_input), tuple(b_input), *m_inputs)
+            output = ()
+            tval = qtn.array_contract(
+                arrays, inputs, output, **self.contract_opts
+            )
+            tvals.append(tval)
+
+        mvals = []
+        for ix, tids in self.tn.ind_map.items():
+            if ix in self.output_inds:
+                continue
+            tida, tidb = tids
+            ml = self.messages[ix, tidb]
+            mr = self.messages[ix, tida]
+            mval = qtn.array_contract(
+                (ml, mr), ((1, 2), (1, 2)), (), **self.contract_opts
+            )
+            mvals.append(mval)
+
+        return combine_local_contractions(
+            tvals, mvals, self.backend, strip_exponent=strip_exponent
+        )
+
+    def compress(
+        self,
+        max_bond,
+        cutoff=0.0,
+        cutoff_mode=4,
+        renorm=0,
+        inplace=False,
+    ):
+        """Compress the initial tensor network using the current messages."""
+        tn = self.tn if inplace else self.tn.copy()
+
+        for ix, tids in tn.ind_map.items():
+            if len(tids) != 2:
+                continue
+            tida, tidb = tids
+
+            # messages are left and right factors squared already
+            ta = tn.tensor_map[tida]
+            dm = ta.ind_size(ix)
+            dl = ta.size // dm
+            ml = self.messages[ix, tidb]
+            Rl = qtn.decomp.squared_op_to_reduced_factor(
+                ml, dl, dm, right=True
+            )
+
+            tb = tn.tensor_map[tidb]
+            dr = tb.size // dm
+            mr = self.messages[ix, tida].T
+            Rr = qtn.decomp.squared_op_to_reduced_factor(
+                mr, dm, dr, right=False
+            )
+
+            # compute the compressors
+            Pl, Pr = qtn.decomp.compute_oblique_projectors(
+                Rl,
+                Rr,
+                max_bond=max_bond,
+                cutoff=cutoff,
+                cutoff_mode=cutoff_mode,
+                renorm=renorm,
+            )
+
+            # contract the compressors into the tensors
+            tn.tensor_map[tida].gate_(Pl.T, ix)
+            tn.tensor_map[tidb].gate_(Pr, ix)
+
+            # update messages with projections
+            if inplace:
+                new_Ra = Rl @ Pl
+                new_Rb = Pr @ Rr
+                self.messages[ix, tidb] = ar.dag(new_Ra) @ new_Ra
+                self.messages[ix, tida] = new_Rb @ ar.dag(new_Rb)
+
+        return tn
+
+
+def contract_d2bp(
+    tn,
+    messages=None,
+    output_inds=None,
+    optimize="auto-hq",
+    local_convergence=True,
+    damping=0.0,
+    max_iterations=1000,
+    tol=5e-6,
+    strip_exponent=False,
+    progbar=False,
+    **contract_opts,
+):
+    """Estimate the norm squared of ``tn`` using dense 2-norm belief
+    propagation.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to form the 2-norm of and run BP on.
+    messages : dict[(str, int), array_like], optional
+        The initial messages to use, effectively defaults to all ones if not
+        specified.
+    max_iterations : int, optional
+        The maximum number of iterations to perform.
+    tol : float, optional
+        The convergence tolerance for messages.
+    output_inds : set[str], optional
+        The indices to consider as output (dangling) indices of the tn.
+        Computed automatically if not specified.
+    optimize : str or PathOptimizer, optional
+        The path optimizer to use when contracting the messages.
+    local_convergence : bool, optional
+        Whether to allow messages to locally converge - i.e. if all their
+        input messages have converged then stop updating them.
+    damping : float, optional
+        The damping parameter to use, defaults to no damping.
+    strip_exponent : bool, optional
+        Whether to strip the exponent from the final result. If ``True``
+        then the returned result is ``(mantissa, exponent)``.
+    progbar : bool, optional
+        Whether to show a progress bar.
+    contract_opts
+        Other options supplied to ``cotengra.array_contract``.
+
+    Returns
+    -------
+    scalar or (scalar, float)
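+
+    Examples
+    --------
+    A minimal sketch, assuming a small random PEPS::
+
+        import quimb.tensor as qtn
+
+        peps = qtn.PEPS.rand(3, 4, bond_dim=3, seed=42)
+        norm_sq_estimate = contract_d2bp(peps)
+
+        # for large networks, separate the exponent to avoid overflow
+        mantissa, exponent = contract_d2bp(peps, strip_exponent=True)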
+    """
+    bp = D2BP(
+        tn,
+        messages=messages,
+        output_inds=output_inds,
+        optimize=optimize,
+        local_convergence=local_convergence,
+        damping=damping,
+        **contract_opts,
+    )
+    bp.run(
+        max_iterations=max_iterations,
+        tol=tol,
+        progbar=progbar,
+    )
+    return bp.contract(strip_exponent=strip_exponent)
+
+
+def compress_d2bp(
+    tn,
+    max_bond,
+    cutoff=0.0,
+    cutoff_mode="rsum2",
+    renorm=0,
+    messages=None,
+    output_inds=None,
+    optimize="auto-hq",
+    local_convergence=True,
+    damping=0.0,
+    max_iterations=1000,
+    tol=5e-6,
+    inplace=False,
+    progbar=False,
+    **contract_opts,
+):
+    """Compress the tensor network ``tn`` using dense 2-norm belief
+    propagation.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to form the 2-norm of, run BP on and then compress.
+    max_bond : int
+        The maximum bond dimension to compress to.
+    cutoff : float, optional
+        The cutoff to use when compressing.
+    cutoff_mode : str or int, optional
+        The cutoff mode to use when compressing, e.g. ``"rsum2"``.
+    messages : dict[(str, int), array_like], optional
+        The initial messages to use, effectively defaults to all ones if not
+        specified.
+    max_iterations : int, optional
+        The maximum number of iterations to perform.
+    tol : float, optional
+        The convergence tolerance for messages.
+    output_inds : set[str], optional
+        The indices to consider as output (dangling) indices of the tn.
+        Computed automatically if not specified.
+    optimize : str or PathOptimizer, optional
+        The path optimizer to use when contracting the messages.
+    local_convergence : bool, optional
+        Whether to allow messages to locally converge - i.e. if all their
+        input messages have converged then stop updating them.
+    damping : float, optional
+        The damping parameter to use, defaults to no damping.
+    inplace : bool, optional
+        Whether to perform the compression inplace.
+    progbar : bool, optional
+        Whether to show a progress bar.
+    contract_opts
+        Other options supplied to ``cotengra.array_contract``.
+
+    Returns
+    -------
+    TensorNetwork
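+
+    Examples
+    --------
+    A minimal sketch, assuming a small random PEPS to be compressed from
+    bond dimension 4 down to 2::
+
+        import quimb.tensor as qtn
+
+        peps = qtn.PEPS.rand(4, 4, bond_dim=4, seed=42)
+        peps_compressed = compress_d2bp(peps, max_bond=2)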
+    """
+    bp = D2BP(
+        tn,
+        messages=messages,
+        output_inds=output_inds,
+        optimize=optimize,
+        local_convergence=local_convergence,
+        damping=damping,
+        **contract_opts,
+    )
+    bp.run(
+        max_iterations=max_iterations,
+        tol=tol,
+        progbar=progbar,
+    )
+    return bp.compress(
+        max_bond=max_bond,
+        cutoff=cutoff,
+        cutoff_mode=cutoff_mode,
+        renorm=renorm,
+        inplace=inplace,
+    )
+
+
+def sample_d2bp(
+    tn,
+    output_inds=None,
+    messages=None,
+    max_iterations=100,
+    tol=1e-2,
+    bias=None,
+    seed=None,
+    local_convergence=True,
+    progbar=False,
+    **contract_opts,
+):
+    """Sample a configuration from ``tn`` using dense 2-norm belief
+    propagation.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to sample from.
+    output_inds : set[str], optional
+        Which indices to sample.
+    messages : dict[(str, int), array_like], optional
+        The initial messages to use, effectively defaults to all ones if not
+        specified.
+    max_iterations : int, optional
+        The maximum number of iterations to perform, per marginal.
+    tol : float, optional
+        The convergence tolerance for messages.
+    bias : float, optional
+        Bias the sampling towards more locally likely bit-strings. This is
+        done by raising the BP marginal of each index to this power before
+        sampling from it.
+    seed : int, optional
+        A random seed for reproducibility.
+    local_convergence : bool, optional
+        Whether to allow messages to locally converge - i.e. if all their
+        input messages have converged then stop updating them.
+    progbar : bool, optional
+        Whether to show a progress bar.
+    contract_opts
+        Other options supplied to ``cotengra.array_contract``.
+
+    Returns
+    -------
+    config : dict[str, int]
+        The sampled configuration, a mapping of output indices to values.
+    tn_config : TensorNetwork
+        The tensor network with the sampled configuration applied.
+    omega : float
+        The BP probability of the sampled configuration.
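+
+    Examples
+    --------
+    A minimal sketch, assuming a small random PEPS normalized before
+    sampling::
+
+        import quimb.tensor as qtn
+
+        peps = qtn.PEPS.rand(3, 3, bond_dim=2, seed=42)
+        peps /= (peps.H @ peps) ** 0.5
+        config, peps_config, omega = sample_d2bp(peps, seed=0)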
+    """
+    import numpy as np
+
+    if output_inds is None:
+        output_inds = tn.outer_inds()
+
+    rng = np.random.default_rng(seed)
+    config = {}
+    omega = 1.0
+
+    tn = tn.copy()
+    bp = D2BP(
+        tn,
+        messages=messages,
+        local_convergence=local_convergence,
+        **contract_opts,
+    )
+    bp.run(max_iterations=max_iterations, tol=tol)
+
+    marginals = dict.fromkeys(output_inds)
+
+    if progbar:
+        import tqdm
+
+        pbar = tqdm.tqdm(total=len(marginals))
+    else:
+        pbar = None
+
+    while marginals:
+        for ix in marginals:
+            marginals[ix] = bp.compute_marginal(ix)
+
+        ix, p = max(marginals.items(), key=lambda x: max(x[1]))
+        p = ar.to_numpy(p)
+
+        if bias is not None:
+            # bias distribution towards more locally likely bit-strings
+            p = p**bias
+            p /= np.sum(p)
+
+        v = rng.choice([0, 1], p=p)
+        config[ix] = v
+        del marginals[ix]
+
+        tids = tuple(tn.ind_map[ix])
+        tn.isel_({ix: v})
+
+        omega *= p[v]
+        if progbar:
+            pbar.update(1)
+            pbar.set_description(f"{ix}->{v}", refresh=False)
+
+        bp = D2BP(
+            tn,
+            messages=bp.messages,
+            local_convergence=local_convergence,
+            **contract_opts,
+        )
+        bp.update_touched_from_tids(*tids)
+        bp.run(tol=tol, max_iterations=max_iterations)
+
+    if progbar:
+        pbar.close()
+
+    return config, tn, omega
diff --git a/tests/test_tensor/test_belief_propagation/__init__.py b/tests/test_tensor/test_belief_propagation/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_tensor/test_belief_propagation/test_d2bp.py b/tests/test_tensor/test_belief_propagation/test_d2bp.py
new file mode 100644
index 00000000..0bdd9b06
--- /dev/null
+++ b/tests/test_tensor/test_belief_propagation/test_d2bp.py
@@ -0,0 +1,48 @@
+import pytest
+
+import quimb.tensor as qtn
+from quimb.experimental.belief_propagation.d2bp import (
+    contract_d2bp,
+    compress_d2bp,
+    sample_d2bp,
+)
+
+
+@pytest.mark.parametrize("damping", [0.0, 0.1, 0.5])
+def test_contract(damping):
+    peps = qtn.PEPS.rand(3, 4, 3, seed=42)
+    # normalize exactly
+    peps /= (peps.H @ peps) ** 0.5
+    N_ap = contract_d2bp(peps, damping=damping)
+    assert N_ap == pytest.approx(1.0, rel=0.3)
+
+
+@pytest.mark.parametrize("damping", [0.0, 0.1, 0.5])
+def test_compress(damping):
+    peps = qtn.PEPS.rand(3, 4, 3, seed=42)
+    # test that using the BP compression gives better fidelity than purely
+    # local, naive compression scheme
+    peps_c1 = peps.compress_all(max_bond=2)
+    peps_c2 = compress_d2bp(peps, max_bond=2, damping=damping)
+    fid1 = peps_c1.H @ peps_c2
+    fid2 = peps_c2.H @ peps_c2
+    assert abs(fid2) > abs(fid1)
+
+
+def test_sample():
+    peps = qtn.PEPS.rand(3, 4, 3, seed=42)
+    # normalize exactly
+    peps /= (peps.H @ peps) ** 0.5
+    config, peps_config, omega = sample_d2bp(peps, seed=42)
+    assert all(ix in config for ix in peps.site_inds)
+    assert 0.0 < omega < 1.0
+    assert peps_config.outer_inds() == ()
+
+    ptotal = 0.0
+    nrepeat = 4
+    for _ in range(nrepeat):
+        _, peps_config, _ = sample_d2bp(peps, seed=42)
+        ptotal += peps_config.contract()**2
+
+    # check we are doing better than random guessing
+    assert ptotal > nrepeat * 2**-peps.nsites