From 8e6e290c6d2c25c2847497ca5dfd1d73104baf75 Mon Sep 17 00:00:00 2001
From: Johnnie Gray
Date: Fri, 20 Oct 2023 12:06:22 -0700
Subject: [PATCH] add lazy 1-norm belief propagation

---
 quimb/experimental/belief_propagation/d2bp.py |   8 +
 quimb/experimental/belief_propagation/l1bp.py | 259 ++++++++++++++++++
 .../test_belief_propagation/test_l1bp.py      |  91 ++++++
 3 files changed, 358 insertions(+)
 create mode 100644 quimb/experimental/belief_propagation/l1bp.py
 create mode 100644 tests/test_tensor/test_belief_propagation/test_l1bp.py

diff --git a/quimb/experimental/belief_propagation/d2bp.py b/quimb/experimental/belief_propagation/d2bp.py
index d66a67b6..9338f94d 100644
--- a/quimb/experimental/belief_propagation/d2bp.py
+++ b/quimb/experimental/belief_propagation/d2bp.py
@@ -12,6 +12,14 @@ class D2BP(BeliefPropagationCommon):
    operators) belief propagation. Allows message reuse. This version
    assumes no hyper indices (i.e. a standard PEPS-like tensor network).

    Potential use cases for D2BP and a PEPS-like tensor network are:

    - globally compressing it from bond dimension ``D`` to ``D'``
    - eagerly applying gates and locally compressing back to ``D``
    - sampling configurations
    - estimating the norm of the tensor network

    Parameters
    ----------
    tn : TensorNetwork
diff --git a/quimb/experimental/belief_propagation/l1bp.py b/quimb/experimental/belief_propagation/l1bp.py
new file mode 100644
index 00000000..f7382457
--- /dev/null
+++ b/quimb/experimental/belief_propagation/l1bp.py
@@ -0,0 +1,259 @@
import autoray as ar

import quimb.tensor as qtn

from .bp_common import (
    BeliefPropagationCommon,
    create_lazy_community_edge_map,
    combine_local_contractions,
)


class L1BP(BeliefPropagationCommon):
    """Lazy 1-norm belief propagation. BP is run between groups of tensors
    defined by ``site_tags``. The message updates are lazy contractions.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to run BP on.
    site_tags : sequence of str, optional
        The tags identifying the sites in ``tn``, each tag forming a
        region. Regions should not overlap. If the tensor network is
        structured, then these are inferred automatically.
    damping : float, optional
        The damping parameter to use, defaults to no damping.
    local_convergence : bool, optional
        Whether to allow messages to locally converge: if all of a
        message's input messages have converged then stop updating it.
    optimize : str or PathOptimizer, optional
        The path optimizer to use when contracting the messages.
    message_init_function : callable, optional
        If supplied, initialize each message tensor by calling this
        function with the message shape, rather than by contracting the
        local network.
    contract_opts
        Other options supplied to ``cotengra.array_contract``.
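    Examples
    --------
    A minimal usage sketch (illustrative only; ``TN2D_rand`` is the same
    builder used by the tests in this patch, and the arguments here are
    placeholders rather than recommended values)::

        import quimb.tensor as qtn
        from quimb.experimental.belief_propagation.l1bp import L1BP

        # a small loopy tensor network with bond dimension 5
        tn = qtn.TN2D_rand(3, 4, 5)

        bp = L1BP(tn)
        bp.run(max_iterations=1000, tol=5e-6)

        # approximate contraction value of the whole network
        Z = bp.contract()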
+ """ + + def __init__( + self, + tn, + site_tags=None, + damping=0.0, + local_convergence=True, + optimize="auto-hq", + message_init_function=None, + **contract_opts, + ): + self.backend = next(t.backend for t in tn) + self.damping = damping + self.local_convergence = local_convergence + self.optimize = optimize + self.contract_opts = contract_opts + + if site_tags is None: + self.site_tags = tuple(tn.site_tags) + else: + self.site_tags = tuple(site_tags) + + ( + self.edges, + self.neighbors, + self.local_tns, + self.touch_map, + ) = create_lazy_community_edge_map(tn, site_tags) + self.touched = set() + + self._abs = ar.get_lib_fn(self.backend, "abs") + self._max = ar.get_lib_fn(self.backend, "max") + self._sum = ar.get_lib_fn(self.backend, "sum") + _real = ar.get_lib_fn(self.backend, "real") + _argmax = ar.get_lib_fn(self.backend, "argmax") + _reshape = ar.get_lib_fn(self.backend, "reshape") + self._norm = ar.get_lib_fn(self.backend, "linalg.norm") + + def _normalize(x): + return x / self._sum(x) + # return x / self._norm(x) + # return x / self._max(x) + # fx = _reshape(x, (-1,)) + # return x / fx[_argmax(self._abs(_real(fx)))] + + def _distance(x, y): + return self._sum(self._abs(x - y)) + + self._normalize = _normalize + self._distance = _distance + + # for each meta bond create initial messages + self.messages = {} + for pair, bix in self.edges.items(): + # compute leftwards and rightwards messages + for i, j in (sorted(pair), sorted(pair, reverse=True)): + tn_i = self.local_tns[i] + # initial message just sums over dangling bonds + + if message_init_function is None: + tm = tn_i.contract( + output_inds=bix, + optimize=self.optimize, + drop_tags=True, + **self.contract_opts, + ) + # normalize + tm.modify(apply=self._normalize) + else: + shape = tuple(tn_i.ind_size(ix) for ix in bix) + tm = qtn.Tensor( + data=message_init_function(shape), + inds=bix, + ) + + self.messages[i, j] = tm + + # compute the contractions + self.contraction_tns = {} + for pair, bix in self.edges.items(): + # for each meta bond compute left and right contractions + for i, j in (sorted(pair), sorted(pair, reverse=True)): + tn_i = self.local_tns[i].copy() + # attach incoming messages to dangling bonds + tks = [ + self.messages[k, i] for k in self.neighbors[i] if k != j + ] + # virtual so we can modify messages tensors inplace + tn_i_to_j = qtn.TensorNetwork((tn_i, *tks), virtual=True) + self.contraction_tns[i, j] = tn_i_to_j + + def iterate(self, tol=5e-6): + if (not self.local_convergence) or (not self.touched): + # assume if asked to iterate that we want to check all messages + self.touched.update( + pair for edge in self.edges for pair in (edge, edge[::-1]) + ) + + ncheck = len(self.touched) + new_data = {} + while self.touched: + i, j = self.touched.pop() + + bix = self.edges[(i, j) if i < j else (j, i)] + tn_i_to_j = self.contraction_tns[i, j] + tm_new = tn_i_to_j.contract( + output_inds=bix, + optimize=self.optimize, + **self.contract_opts, + ) + m = self._normalize(tm_new.data) + + new_data[i, j] = m + + nconv = 0 + max_mdiff = -1.0 + for key, data in new_data.items(): + tm = self.messages[key] + + if self.damping != 0.0: + data = (1 - self.damping) * data + self.damping * tm.data + + mdiff = float(self._distance(tm.data, data)) + + if mdiff > tol: + # mark touching messages for update + self.touched.update(self.touch_map[key]) + else: + nconv += 1 + + max_mdiff = max(max_mdiff, mdiff) + tm.modify(data=data) + + return nconv, ncheck, max_mdiff + + def contract(self, strip_exponent=False): + tvals = [] + for 
    def contract(self, strip_exponent=False):
        tvals = []
        for site, tn_ic in self.local_tns.items():
            if site in self.neighbors:
                tval = qtn.tensor_contract(
                    *tn_ic,
                    *(self.messages[k, site] for k in self.neighbors[site]),
                    optimize=self.optimize,
                    **self.contract_opts,
                )
            else:
                # site exists but has no neighbors
                tval = tn_ic.contract(
                    output_inds=(),
                    optimize=self.optimize,
                    **self.contract_opts,
                )
            tvals.append(tval)

        mvals = []
        for i, j in self.edges:
            mval = qtn.tensor_contract(
                self.messages[i, j],
                self.messages[j, i],
                optimize=self.optimize,
                **self.contract_opts,
            )
            mvals.append(mval)

        return combine_local_contractions(
            tvals, mvals, self.backend, strip_exponent=strip_exponent
        )


def contract_l1bp(
    tn,
    max_iterations=1000,
    tol=5e-6,
    site_tags=None,
    damping=0.0,
    local_convergence=True,
    optimize="auto-hq",
    strip_exponent=False,
    progbar=False,
    **contract_opts,
):
    """Estimate the contraction of ``tn`` using lazy 1-norm belief
    propagation.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to contract.
    max_iterations : int, optional
        The maximum number of iterations to perform.
    tol : float, optional
        The convergence tolerance for messages.
    site_tags : sequence of str, optional
        The tags identifying the sites in ``tn``, each tag forming a
        region. If the tensor network is structured, then these are
        inferred automatically.
    damping : float, optional
        The damping parameter to use, defaults to no damping.
    local_convergence : bool, optional
        Whether to allow messages to locally converge: if all of a
        message's input messages have converged then stop updating it.
    optimize : str or PathOptimizer, optional
        The path optimizer to use when contracting the messages.
    strip_exponent : bool, optional
        Whether to strip the exponent from the final result. If ``True``
        then the returned result is ``(mantissa, exponent)``.
    progbar : bool, optional
        Whether to show a progress bar.
    contract_opts
        Other options supplied to ``cotengra.array_contract``.
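    Examples
    --------
    A minimal usage sketch (illustrative only; the builders and values
    mirror the tests in this patch rather than any recommended settings)::

        import quimb.tensor as qtn
        from quimb.experimental.belief_propagation.l1bp import contract_l1bp

        # exact on trees
        tn = qtn.TN_rand_tree(10, 3, seed=42)
        Z = contract_l1bp(tn)

        # approximate on loopy networks, optionally with damping
        tn2 = qtn.TN2D_rand(3, 4, 5)
        Z2 = contract_l1bp(tn2, damping=0.1)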
+ """ + bp = L1BP( + tn, + site_tags=site_tags, + damping=damping, + local_convergence=local_convergence, + optimize=optimize, + **contract_opts, + ) + bp.run( + max_iterations=max_iterations, + tol=tol, + progbar=progbar, + ) + return bp.contract( + strip_exponent=strip_exponent, + ) diff --git a/tests/test_tensor/test_belief_propagation/test_l1bp.py b/tests/test_tensor/test_belief_propagation/test_l1bp.py new file mode 100644 index 00000000..569fad39 --- /dev/null +++ b/tests/test_tensor/test_belief_propagation/test_l1bp.py @@ -0,0 +1,91 @@ +import pytest + +import quimb as qu +import quimb.tensor as qtn +from quimb.experimental.belief_propagation.l1bp import contract_l1bp +from quimb.experimental.belief_propagation.d2bp import contract_d2bp + + +@pytest.mark.parametrize("dtype", ["float32", "complex64"]) +def test_contract_tree_exact(dtype): + tn = qtn.TN_rand_tree(10, 3, seed=42, dtype=dtype) + Z_ex = tn.contract() + Z_bp = contract_l1bp(tn) + assert Z_ex == pytest.approx(Z_bp, rel=5e-6) + + +@pytest.mark.parametrize("dtype", ["float32", "complex64"]) +@pytest.mark.parametrize("damping", [0.0, 0.1]) +def test_contract_loopy_approx(dtype, damping): + tn = qtn.TN2D_rand(3, 4, 5, dtype=dtype, dist="uniform") + Z_ex = tn.contract() + Z_bp = contract_l1bp(tn, damping=damping) + assert Z_ex == pytest.approx(Z_bp, rel=0.1) + + +@pytest.mark.parametrize("dtype", ["float32", "complex64"]) +@pytest.mark.parametrize("damping", [0.0, 0.1]) +def test_contract_double_loopy_approx(dtype, damping): + peps = qtn.PEPS.rand(4, 3, 2, seed=42, dtype=dtype) + tn = peps.H & peps + Z_ex = tn.contract() + Z_bp1 = contract_l1bp(tn, damping=damping) + assert Z_bp1 == pytest.approx(Z_ex, rel=0.3) + # compare with 2-norm BP on the peps directly + Z_bp2 = contract_d2bp(peps) + assert Z_bp1 == pytest.approx(Z_bp2, rel=5e-6) + + +@pytest.mark.parametrize("dtype", ["float32", "complex64"]) +def test_contract_tree_triple_sandwich_exact(dtype): + edges = qtn.edges_tree_rand(20, 3, seed=42) + ket = qtn.TN_from_edges_rand( + edges, + 3, + phys_dim=2, + seed=42, + site_ind_id="k{}", + dtype=dtype, + ) + op = qtn.TN_from_edges_rand( + edges, + 2, + phys_dim=2, + seed=42, + site_ind_id=("k{}", "b{}"), + dtype=dtype, + ) + bra = qtn.TN_from_edges_rand( + edges, + 3, + phys_dim=2, + seed=42, + site_ind_id="b{}", + dtype=dtype, + ) + tn = bra.H | op | ket + Z_ex = tn.contract() + Z_bp = contract_l1bp(tn) + assert Z_ex == pytest.approx(Z_bp, rel=5e-6) + +@pytest.mark.parametrize("dtype", ["float32", "complex64"]) +@pytest.mark.parametrize("damping", [0.0, 0.1]) +def test_contract_tree_triple_sandwich_loopy_approx(dtype, damping): + edges = qtn.edges_2d_hexagonal(2, 3) + ket = qtn.TN_from_edges_rand( + edges, + 3, + phys_dim=2, + seed=42, + site_ind_id="k{}", + dtype=dtype, + # make the wavefunction postive to make easier + dist='uniform', + ) + ket /= (ket.H @ ket)**0.5 + + G_ket = ket.gate(qu.pauli('Z'), [(1, 1, 'A')], propagate_tags="sites") + tn = ket.H | G_ket + Z_ex = tn.contract() + Z_bp = contract_l1bp(tn, damping=damping) + assert Z_bp == pytest.approx(Z_ex, rel=0.5)