Skip to content

Commit

Permalink
add lazy 1-norm belief propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Oct 20, 2023
1 parent 11d1c3a commit 8e6e290
Show file tree
Hide file tree
Showing 3 changed files with 358 additions and 0 deletions.
8 changes: 8 additions & 0 deletions quimb/experimental/belief_propagation/d2bp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ class D2BP(BeliefPropagationCommon):
operators) belief propagation. Allows messages reuse. This version assumes
no hyper indices (i.e. a standard PEPS like tensor network).
Potential use cases for D2BP and a PEPS like tensor network are:
- globally compressing it from bond dimension ``D`` to ``D'``
- eagerly applying gates and locally compressing back to ``D``
- sampling configurations
- estimating the norm of the tensor network
Parameters
----------
tn : TensorNetwork
Expand Down
259 changes: 259 additions & 0 deletions quimb/experimental/belief_propagation/l1bp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
import autoray as ar

import quimb.tensor as qtn

from .bp_common import (
BeliefPropagationCommon,
create_lazy_community_edge_map,
combine_local_contractions,
)


class L1BP(BeliefPropagationCommon):
    """Lazy 1-norm belief propagation. BP is run between groups of tensors
    defined by ``site_tags``. The message updates are lazy contractions.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to run BP on.
    site_tags : sequence of str, optional
        The tags identifying the sites in ``tn``, each tag forms a region,
        which should not overlap. If the tensor network is structured, then
        these are inferred automatically.
    damping : float, optional
        The damping parameter to use, defaults to no damping.
    local_convergence : bool, optional
        Whether to allow messages to locally converge - i.e. if all their
        input messages have converged then stop updating them.
    optimize : str or PathOptimizer, optional
        The path optimizer to use when contracting the messages.
    message_init_function : callable, optional
        If supplied, called with each message's shape (a tuple of ints) to
        produce its initial data. If not supplied, messages are initialized
        by contracting each local region down to its bond indices.
    contract_opts
        Other options supplied to ``cotengra.array_contract``.
    """

    def __init__(
        self,
        tn,
        site_tags=None,
        damping=0.0,
        local_convergence=True,
        optimize="auto-hq",
        message_init_function=None,
        **contract_opts,
    ):
        # infer the array backend from the first tensor in the network
        self.backend = next(t.backend for t in tn)
        self.damping = damping
        self.local_convergence = local_convergence
        self.optimize = optimize
        self.contract_opts = contract_opts

        if site_tags is None:
            self.site_tags = tuple(tn.site_tags)
        else:
            self.site_tags = tuple(site_tags)

        # group tensors into site regions and map out which regions share
        # bonds with which others
        # NOTE(review): passes the raw ``site_tags`` (possibly None) rather
        # than ``self.site_tags`` - presumably the helper re-infers tags
        # from ``tn`` in the None case; confirm against its signature.
        (
            self.edges,
            self.neighbors,
            self.local_tns,
            self.touch_map,
        ) = create_lazy_community_edge_map(tn, site_tags)
        # set of (i, j) message keys that still need (re)computation
        self.touched = set()

        # resolve backend-specific array functions once up front
        self._abs = ar.get_lib_fn(self.backend, "abs")
        self._max = ar.get_lib_fn(self.backend, "max")
        self._sum = ar.get_lib_fn(self.backend, "sum")
        # _real, _argmax, _reshape are only needed by the alternative
        # normalization strategies kept commented out below
        _real = ar.get_lib_fn(self.backend, "real")
        _argmax = ar.get_lib_fn(self.backend, "argmax")
        _reshape = ar.get_lib_fn(self.backend, "reshape")
        self._norm = ar.get_lib_fn(self.backend, "linalg.norm")

        def _normalize(x):
            # normalize a message so its entries sum to one
            return x / self._sum(x)
            # alternative normalization choices, kept for experimentation:
            # return x / self._norm(x)
            # return x / self._max(x)
            # fx = _reshape(x, (-1,))
            # return x / fx[_argmax(self._abs(_real(fx)))]

        def _distance(x, y):
            # 1-norm distance between two messages - the convergence measure
            return self._sum(self._abs(x - y))

        self._normalize = _normalize
        self._distance = _distance

        # for each meta bond create initial messages
        self.messages = {}
        for pair, bix in self.edges.items():
            # compute leftwards and rightwards messages
            for i, j in (sorted(pair), sorted(pair, reverse=True)):
                tn_i = self.local_tns[i]
                # initial message just sums over dangling bonds

                if message_init_function is None:
                    tm = tn_i.contract(
                        output_inds=bix,
                        optimize=self.optimize,
                        drop_tags=True,
                        **self.contract_opts,
                    )
                    # normalize
                    tm.modify(apply=self._normalize)
                else:
                    # custom initialization from the message's shape
                    shape = tuple(tn_i.ind_size(ix) for ix in bix)
                    tm = qtn.Tensor(
                        data=message_init_function(shape),
                        inds=bix,
                    )

                self.messages[i, j] = tm

        # compute the contractions
        self.contraction_tns = {}
        for pair, bix in self.edges.items():
            # for each meta bond compute left and right contractions
            for i, j in (sorted(pair), sorted(pair, reverse=True)):
                tn_i = self.local_tns[i].copy()
                # attach incoming messages to dangling bonds
                tks = [
                    self.messages[k, i] for k in self.neighbors[i] if k != j
                ]
                # virtual so we can modify messages tensors inplace
                tn_i_to_j = qtn.TensorNetwork((tn_i, *tks), virtual=True)
                self.contraction_tns[i, j] = tn_i_to_j

    def iterate(self, tol=5e-6):
        """Perform a single round of message updates.

        Parameters
        ----------
        tol : float, optional
            Messages whose 1-norm change is below this are counted as
            converged; otherwise their dependent messages are re-queued.

        Returns
        -------
        nconv : int
            The number of messages that converged this round.
        ncheck : int
            The number of messages checked this round.
        max_mdiff : float
            The maximum message change this round (-1.0 if none checked).
        """
        if (not self.local_convergence) or (not self.touched):
            # assume if asked to iterate that we want to check all messages
            self.touched.update(
                pair for edge in self.edges for pair in (edge, edge[::-1])
            )

        ncheck = len(self.touched)
        # first compute all new message data...
        new_data = {}
        while self.touched:
            i, j = self.touched.pop()

            # edge keys are stored with sorted endpoints
            bix = self.edges[(i, j) if i < j else (j, i)]
            tn_i_to_j = self.contraction_tns[i, j]
            tm_new = tn_i_to_j.contract(
                output_inds=bix,
                optimize=self.optimize,
                **self.contract_opts,
            )
            m = self._normalize(tm_new.data)

            new_data[i, j] = m

        # ...then insert it, so every update above saw only old messages
        nconv = 0
        max_mdiff = -1.0
        for key, data in new_data.items():
            tm = self.messages[key]

            if self.damping != 0.0:
                # mix the new message with the previous one
                data = (1 - self.damping) * data + self.damping * tm.data

            mdiff = float(self._distance(tm.data, data))

            if mdiff > tol:
                # mark touching messages for update
                self.touched.update(self.touch_map[key])
            else:
                nconv += 1

            max_mdiff = max(max_mdiff, mdiff)
            # in-place update: the virtual contraction TNs see the new data
            tm.modify(data=data)

        return nconv, ncheck, max_mdiff

    def contract(self, strip_exponent=False):
        """Estimate the full tensor network contraction value from the
        current messages.

        Parameters
        ----------
        strip_exponent : bool, optional
            Whether to return ``(mantissa, exponent)`` rather than a single
            scalar - passed on to ``combine_local_contractions``.
        """
        # local contraction value of each site region with all its incoming
        # messages attached
        tvals = []
        for site, tn_ic in self.local_tns.items():
            if site in self.neighbors:
                tval = qtn.tensor_contract(
                    *tn_ic,
                    *(self.messages[k, site] for k in self.neighbors[site]),
                    optimize=self.optimize,
                    **self.contract_opts,
                )
            else:
                # site exists but has no neighbors
                tval = tn_ic.contract(
                    output_inds=(),
                    optimize=self.optimize,
                    **self.contract_opts,
                )
            tvals.append(tval)

        # overlap of the two opposing messages on each edge - presumably
        # used by ``combine_local_contractions`` to correct for each bond
        # being counted in both adjacent regions
        mvals = []
        for i, j in self.edges:
            mval = qtn.tensor_contract(
                self.messages[i, j],
                self.messages[j, i],
                optimize=self.optimize,
                **self.contract_opts,
            )
            mvals.append(mval)

        return combine_local_contractions(
            tvals, mvals, self.backend, strip_exponent=strip_exponent
        )


def contract_l1bp(
    tn,
    max_iterations=1000,
    tol=5e-6,
    site_tags=None,
    damping=0.0,
    local_convergence=True,
    optimize="auto-hq",
    strip_exponent=False,
    progbar=False,
    **contract_opts,
):
    """Estimate the contraction of ``tn`` using lazy 1-norm belief
    propagation: run BP to (approximate) convergence, then combine the
    local region and message contractions into a single estimate.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to contract.
    max_iterations : int, optional
        The maximum number of iterations to perform.
    tol : float, optional
        The convergence tolerance for messages.
    site_tags : sequence of str, optional
        The tags identifying the sites in ``tn``, each tag forms a region. If
        the tensor network is structured, then these are inferred
        automatically.
    damping : float, optional
        The damping parameter to use, defaults to no damping.
    local_convergence : bool, optional
        Whether to allow messages to locally converge - i.e. if all their
        input messages have converged then stop updating them.
    optimize : str or PathOptimizer, optional
        The path optimizer to use when contracting the messages.
    strip_exponent : bool, optional
        Whether to strip the exponent from the final result. If ``True``
        then the returned result is ``(mantissa, exponent)``.
    progbar : bool, optional
        Whether to show a progress bar.
    contract_opts
        Other options supplied to ``cotengra.array_contract``.

    Returns
    -------
    scalar or (scalar, float)
        The estimated contraction value, or ``(mantissa, exponent)`` if
        ``strip_exponent`` is ``True``.
    """
    bp = L1BP(
        tn,
        site_tags=site_tags,
        damping=damping,
        local_convergence=local_convergence,
        optimize=optimize,
        **contract_opts,
    )
    bp.run(
        max_iterations=max_iterations,
        tol=tol,
        progbar=progbar,
    )
    return bp.contract(
        strip_exponent=strip_exponent,
    )
91 changes: 91 additions & 0 deletions tests/test_tensor/test_belief_propagation/test_l1bp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import pytest

import quimb as qu
import quimb.tensor as qtn
from quimb.experimental.belief_propagation.l1bp import contract_l1bp
from quimb.experimental.belief_propagation.d2bp import contract_d2bp


@pytest.mark.parametrize("dtype", ["float32", "complex64"])
def test_contract_tree_exact(dtype):
    # a tree has no loops, so BP should reproduce the exact contraction
    tree = qtn.TN_rand_tree(10, 3, seed=42, dtype=dtype)
    expected = tree.contract()
    estimate = contract_l1bp(tree)
    assert expected == pytest.approx(estimate, rel=5e-6)


@pytest.mark.parametrize("dtype", ["float32", "complex64"])
@pytest.mark.parametrize("damping", [0.0, 0.1])
def test_contract_loopy_approx(dtype, damping):
    # a 2D grid contains loops, so the BP estimate is only approximate
    grid = qtn.TN2D_rand(3, 4, 5, dtype=dtype, dist="uniform")
    expected = grid.contract()
    estimate = contract_l1bp(grid, damping=damping)
    assert expected == pytest.approx(estimate, rel=0.1)


@pytest.mark.parametrize("dtype", ["float32", "complex64"])
@pytest.mark.parametrize("damping", [0.0, 0.1])
def test_contract_double_loopy_approx(dtype, damping):
    peps = qtn.PEPS.rand(4, 3, 2, seed=42, dtype=dtype)
    norm_tn = peps.H & peps
    expected = norm_tn.contract()
    # lazy 1-norm BP on the doubled (norm) network
    l1_estimate = contract_l1bp(norm_tn, damping=damping)
    assert l1_estimate == pytest.approx(expected, rel=0.3)
    # compare with 2-norm BP on the peps directly
    d2_estimate = contract_d2bp(peps)
    assert l1_estimate == pytest.approx(d2_estimate, rel=5e-6)


@pytest.mark.parametrize("dtype", ["float32", "complex64"])
def test_contract_tree_triple_sandwich_exact(dtype):
    # build <bra| op |ket> on the same random tree - still loop free, so
    # BP should be exact
    edges = qtn.edges_tree_rand(20, 3, seed=42)
    ket = qtn.TN_from_edges_rand(
        edges,
        3,
        phys_dim=2,
        seed=42,
        site_ind_id="k{}",
        dtype=dtype,
    )
    op = qtn.TN_from_edges_rand(
        edges,
        2,
        phys_dim=2,
        seed=42,
        site_ind_id=("k{}", "b{}"),
        dtype=dtype,
    )
    bra = qtn.TN_from_edges_rand(
        edges,
        3,
        phys_dim=2,
        seed=42,
        site_ind_id="b{}",
        dtype=dtype,
    )
    sandwich = bra.H | op | ket
    expected = sandwich.contract()
    estimate = contract_l1bp(sandwich)
    assert expected == pytest.approx(estimate, rel=5e-6)

@pytest.mark.parametrize("dtype", ["float32", "complex64"])
@pytest.mark.parametrize("damping", [0.0, 0.1])
def test_contract_tree_triple_sandwich_loopy_approx(dtype, damping):
    # hexagonal lattice -> loopy network, so only an approximate estimate
    edges = qtn.edges_2d_hexagonal(2, 3)
    ket = qtn.TN_from_edges_rand(
        edges,
        3,
        phys_dim=2,
        seed=42,
        site_ind_id="k{}",
        dtype=dtype,
        # make the wavefunction positive to make convergence easier
        dist='uniform',
    )
    # normalize the state
    ket /= (ket.H @ ket) ** 0.5

    # single-site expectation value of Z as a sandwich network
    gated = ket.gate(qu.pauli('Z'), [(1, 1, 'A')], propagate_tags="sites")
    sandwich = ket.H | gated
    expected = sandwich.contract()
    estimate = contract_l1bp(sandwich, damping=damping)
    assert estimate == pytest.approx(expected, rel=0.5)

0 comments on commit 8e6e290

Please sign in to comment.