From d3cdd9ead6276bf6de256d80f4027ce452374cce Mon Sep 17 00:00:00 2001
From: Johnnie Gray <johnniemcgray@gmail.com>
Date: Tue, 17 Oct 2023 16:24:40 -0700
Subject: [PATCH] add dense 2-norm belief propagation

---
 .../belief_propagation/__init__.py            |  91 +++
 .../belief_propagation/bp_common.py           | 264 ++++++++
 quimb/experimental/belief_propagation/d2bp.py | 604 ++++++++++++++++++
 .../test_belief_propagation/__init__.py       |   0
 .../test_belief_propagation/test_d2bp.py      |  48 ++
 5 files changed, 1007 insertions(+)
 create mode 100644 quimb/experimental/belief_propagation/__init__.py
 create mode 100644 quimb/experimental/belief_propagation/bp_common.py
 create mode 100644 quimb/experimental/belief_propagation/d2bp.py
 create mode 100644 tests/test_tensor/test_belief_propagation/__init__.py
 create mode 100644 tests/test_tensor/test_belief_propagation/test_d2bp.py

diff --git a/quimb/experimental/belief_propagation/__init__.py b/quimb/experimental/belief_propagation/__init__.py
new file mode 100644
index 00000000..6582fe46
--- /dev/null
+++ b/quimb/experimental/belief_propagation/__init__.py
@@ -0,0 +1,91 @@
+"""Belief propagation (BP) routines. There are three potential categorizations
+of BP and each combination of them is a potentially valid specific algorithm.
+
+1-norm vs 2-norm BP
+-------------------
+
+- 1-norm (normal): BP runs directly on the tensor network, messages have size
+  ``d`` where ``d`` is the size of the bond(s) connecting two tensors or
+  regions.
+- 2-norm (quantum): BP runs on the squared tensor network, messages have size
+  ``d^2`` where ``d`` is the size of the bond(s) connecting two tensors or
+  regions. Each local tensor or region is partially traced (over dangling
+  indices) with its conjugate to create a single node.
+
+
+Graph vs Hypergraph BP
+----------------------
+
+- Graph (simple): the tensor network lives on a graph, where indices either
+  appear on two tensors (a bond), or appear on a single tensor (are outputs).
+  In this case, messages are exchanged directly between tensors.
+- Hypergraph: the tensor network lives on a hypergraph, where indices can
+  appear on any number of tensors. In this case, the update procedure has two
+  parts: first all 'tensor' messages are computed, then these are used in the
+  second step to compute all the 'index' messages, which are fed back into
+  the 'tensor' message update, and so forth. For 2-norm BP one likely needs
+  to specify which indices are outputs and should be traced over.
+
+The hypergraph case of course includes the graph case, but since the 'index'
+message update is simply the identity, it is convenient to have a separate,
+simpler implementation where the standard TN bond vs physical index
+definitions hold.
+
+
+Dense vs Vectorized vs Lazy BP
+------------------------------
+
+- Dense: each node is a single tensor, or pair of tensors for 2-norm BP. If
+  all multibonds have been fused, then each message is a vector (1-norm case)
+  or matrix (2-norm case).
+- Vectorized: the same as the above, but all matching tensor and message
+  updates are stacked and performed simultaneously. This can be enormously
+  more efficient for large numbers of small tensors.
+- Lazy: each node is potentially a tensor network itself with arbitrary inner
+  structure and number of bonds connecting to other nodes. The messages are
+  generally tensors and each update is a lazy contraction, which is
+  potentially much cheaper / requires less memory than forming the 'dense'
+  node for large tensors.
+
+(There is also the MPS flavor where each node has a 1D structure and the
+messages are matrix product states, with updates involving compression.)
+
+
+Overall that gives 12 possible BP flavors, some implemented here:
+
+- [x] (HD1BP) hyper, dense, 1-norm - this is the standard BP algorithm
+- [x] (HD2BP) hyper, dense, 2-norm
+- [x] (HV1BP) hyper, vectorized, 1-norm
+- [ ] (HV2BP) hyper, vectorized, 2-norm
+- [ ] (HL1BP) hyper, lazy, 1-norm
+- [ ] (HL2BP) hyper, lazy, 2-norm
+- [ ] (D1BP) simple, dense, 1-norm
+- [x] (D2BP) simple, dense, 2-norm - this is the standard PEPS BP algorithm
+- [ ] (V1BP) simple, vectorized, 1-norm
+- [ ] (V2BP) simple, vectorized, 2-norm
+- [x] (L1BP) simple, lazy, 1-norm
+- [x] (L2BP) simple, lazy, 2-norm
+
+The 2-norm methods can be used to compress bonds or estimate the 2-norm.
+The 1-norm methods can be used to estimate the 1-norm, i.e. the contracted
+value. Both can be used to compute index marginals and thus perform sampling.
+
+The vectorized methods can be extremely fast for large numbers of small
+tensors, but do currently require all dimensions to match.
+
+The dense and lazy methods can converge messages *locally*, i.e. only
+update messages adjacent to messages which have changed.
+"""
+
+from .bp_common import initialize_hyper_messages
+from .d2bp import D2BP, contract_d2bp, compress_d2bp, sample_d2bp
+
+__all__ = (
+    "initialize_hyper_messages",
+    "D2BP",
+    "contract_d2bp",
+    "compress_d2bp",
+    "sample_d2bp",
+)
diff --git a/quimb/experimental/belief_propagation/bp_common.py b/quimb/experimental/belief_propagation/bp_common.py
new file mode 100644
index 00000000..c8cf9b47
--- /dev/null
+++ b/quimb/experimental/belief_propagation/bp_common.py
@@ -0,0 +1,264 @@
+import functools
+import operator
+
+import autoray as ar
+
+import quimb.tensor as qtn
+
+
+class RollingDiffMean:
+    """Tracker for the absolute rolling mean of diffs between values, to
+    assess effective convergence of BP above actual message tolerance.
+    """
+
+    def __init__(self, size=16):
+        self.size = size
+        self.diffs = []
+        self.last_x = None
+        self.dxsum = 0.0
+
+    def update(self, x):
+        if self.last_x is not None:
+            dx = x - self.last_x
+            self.diffs.append(dx)
+            self.dxsum += dx / self.size
+        if len(self.diffs) > self.size:
+            dx = self.diffs.pop(0)
+            self.dxsum -= dx / self.size
+        self.last_x = x
+
+    def absmeandiff(self):
+        if len(self.diffs) < self.size:
+            return float("inf")
+        return abs(self.dxsum)
+
+
+class BeliefPropagationCommon:
+    """Common interface for belief propagation algorithms.
+
+    Parameters
+    ----------
+    max_iterations : int, optional
+        The maximum number of iterations to perform.
+    tol : float, optional
+        The convergence tolerance for messages.
+    progbar : bool, optional
+        Whether to show a progress bar.
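+
+    A subclass is expected to implement ``iterate``, returning a
+    ``(nconv, ncheck, max_mdiff)`` tuple, which ``run`` then calls in a loop
+    while tracking convergence. A minimal sketch, with a dummy (hypothetical)
+    fixed point update standing in for a real message update::
+
+        class TrivialBP(BeliefPropagationCommon):
+
+            def __init__(self):
+                self.messages = {"m": 1.0}
+
+            def iterate(self, tol=5e-6):
+                # a real implementation would update every touched message
+                old = self.messages["m"]
+                new = (old + 2.0 / old) / 2  # converges to 2**0.5
+                self.messages["m"] = new
+                max_mdiff = abs(new - old)
+                nconv = int(max_mdiff < tol)
+                # (number converged, number checked, max message change)
+                return nconv, 1, max_mdiff
+
+        bp = TrivialBP()
+        bp.run(tol=1e-12)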
+    """
+
+    def run(self, max_iterations=1000, tol=5e-6, progbar=False):
+        if progbar:
+            import tqdm
+
+            pbar = tqdm.tqdm()
+        else:
+            pbar = None
+
+        try:
+            it = 0
+            rdm = RollingDiffMean()
+            self.converged = False
+            while not self.converged and it < max_iterations:
+                # can only converge if tol > 0.0
+                nconv, ncheck, max_mdiff = self.iterate(tol=tol)
+                rdm.update(max_mdiff)
+                self.converged = (max_mdiff < tol) or (rdm.absmeandiff() < tol)
+                it += 1
+
+                if pbar is not None:
+                    pbar.set_description(
+                        f"nconv: {nconv}/{ncheck} max|dM|={max_mdiff:.2e}",
+                        refresh=False,
+                    )
+                    pbar.update()
+
+        finally:
+            if pbar is not None:
+                pbar.close()
+
+        if tol != 0.0 and not self.converged:
+            import warnings
+
+            warnings.warn(
+                f"Belief propagation did not converge after {max_iterations} "
+                f"iterations, tol={tol:.2e}, max|dM|={max_mdiff:.2e}."
+            )
+
+
+def prod(xs):
+    """Product of all elements in ``xs``."""
+    return functools.reduce(operator.mul, xs)
+
+
+def initialize_hyper_messages(
+    tn,
+    fill_fn=None,
+    smudge_factor=1e-12,
+):
+    """Initialize messages for belief propagation, this is equivalent to doing
+    a single round of belief propagation with uniform messages.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to initialize messages for.
+    fill_fn : callable, optional
+        A function to fill the messages with, of signature ``fill_fn(shape)``.
+    smudge_factor : float, optional
+        A small number to add to the messages to avoid numerical issues.
+
+    Returns
+    -------
+    messages : dict
+        The initial messages. For every index and tensor id pair, there will
+        be a message to and from, with keys ``(ix, tid)`` and ``(tid, ix)``.
+    """
+    from quimb.tensor.contraction import array_contract
+
+    backend = ar.infer_backend(next(t.data for t in tn))
+    _sum = ar.get_lib_fn(backend, "sum")
+
+    messages = {}
+
+    # compute first messages from tensors to indices
+    for tid, t in tn.tensor_map.items():
+        k_inputs = tuple(range(t.ndim))
+        for i, ix in enumerate(t.inds):
+            if fill_fn is None:
+                # sum over all other indices to get initial message
+                m = array_contract(
+                    arrays=(t.data,),
+                    inputs=(k_inputs,),
+                    output=(i,),
+                )
+                # normalize and insert
+                messages[tid, ix] = m / _sum(m)
+            else:
+                d = t.ind_size(ix)
+                m = fill_fn((d,))
+                messages[tid, ix] = m / _sum(m)
+
+    # compute first messages from indices to tensors
+    for ix, tids in tn.ind_map.items():
+        ms = [messages[tid, ix] for tid in tids]
+        mp = prod(ms)
+        for mi, tid in zip(ms, tids):
+            m = mp / (mi + smudge_factor)
+            # normalize and insert
+            messages[ix, tid] = m / _sum(m)
+
+    return messages
+
+
+def combine_local_contractions(
+    tvals,
+    mvals,
+    backend,
+    strip_exponent=False,
+    check_for_zero=True,
+):
+    """Combine local tensor contraction values ``tvals``, dividing by message
+    contraction values ``mvals``, in a stable mantissa-exponent form.
+    """
+    _abs = ar.get_lib_fn(backend, "abs")
+    _log10 = ar.get_lib_fn(backend, "log10")
+
+    mantissa = 1
+    exponent = 0
+    for vt in tvals:
+        avt = _abs(vt)
+
+        if check_for_zero and (avt == 0.0):
+            if strip_exponent:
+                return 0.0, 0.0
+            else:
+                return 0.0
+
+        mantissa = mantissa * (vt / avt)
+        exponent = exponent + _log10(avt)
+    for mt in mvals:
+        amt = _abs(mt)
+        mantissa = mantissa / (mt / amt)
+        exponent = exponent - _log10(amt)
+
+    if strip_exponent:
+        return mantissa, exponent
+    else:
+        return mantissa * 10**exponent
+
+
+def maybe_get_thread_pool(thread_pool):
+    """Get a thread pool if requested."""
+    if thread_pool is False:
+        return None
+
+    if thread_pool is True:
+        import quimb as qu
+
+        return qu.get_thread_pool()
+
+    if isinstance(thread_pool, int):
+        import quimb as qu
+
+        return qu.get_thread_pool(thread_pool)
+
+    return thread_pool
+
+
+def create_lazy_community_edge_map(tn, site_tags=None, rank_simplify=True):
+    """For lazy BP algorithms, create the data structures describing the
+    effective graph of the lazily grouped 'sites' given by ``site_tags``.
+    """
+    if site_tags is None:
+        site_tags = set(tn.site_tags)
+    else:
+        site_tags = set(site_tags)
+
+    edges = {}
+    neighbors = {}
+    local_tns = {}
+    touch_map = {}
+
+    for ix in tn.ind_map:
+        ts = tn._inds_get(ix)
+        tags = {tag for t in ts for tag in t.tags if tag in site_tags}
+        if len(tags) >= 2:
+            i, j = tuple(sorted(tags))
+
+            if (i, j) in edges:
+                # already processed this edge
+                continue
+
+            # add to neighbor map
+            neighbors.setdefault(i, []).append(j)
+            neighbors.setdefault(j, []).append(i)
+
+            # get local TNs and compute bonds between them,
+            # rank simplify here also to prepare for contractions
+            try:
+                tn_i = local_tns[i]
+            except KeyError:
+                tn_i = local_tns[i] = tn.select(i, virtual=False)
+                if rank_simplify:
+                    tn_i.rank_simplify_()
+            try:
+                tn_j = local_tns[j]
+            except KeyError:
+                tn_j = local_tns[j] = tn.select(j, virtual=False)
+                if rank_simplify:
+                    tn_j.rank_simplify_()
+
+            edges[i, j] = tuple(qtn.bonds(tn_i, tn_j))
+
+    for i, j in edges:
+        touch_map[(i, j)] = tuple((j, k) for k in neighbors[j] if k != i)
+        touch_map[(j, i)] = tuple((i, k) for k in neighbors[i] if k != j)
+
+    if len(local_tns) != len(site_tags):
+        # handle potentially disconnected sites
+        for i in sorted(site_tags):
+            if i in local_tns:
+                continue
+            try:
+                tn_i = local_tns[i] = tn.select(i, virtual=False)
+                if rank_simplify:
+                    tn_i.rank_simplify_()
+            except KeyError:
+                pass
+
+    return edges, neighbors, local_tns, touch_map
diff --git a/quimb/experimental/belief_propagation/d2bp.py b/quimb/experimental/belief_propagation/d2bp.py
new file mode 100644
index 00000000..681d1870
--- /dev/null
+++ b/quimb/experimental/belief_propagation/d2bp.py
@@ -0,0 +1,604 @@
+import autoray as ar
+
+import quimb.tensor as qtn
+
+from .bp_common import (
+    BeliefPropagationCommon,
+    combine_local_contractions,
+)
+
+
+class D2BP(BeliefPropagationCommon):
+    """Dense (as in one tensor per site) 2-norm (as in for wavefunctions and
+    operators) belief propagation. Allows message reuse. This version assumes
+    no hyper indices (i.e. a standard PEPS-like tensor network).
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to form the 2-norm of and run BP on.
+    messages : dict[(str, int), array_like], optional
+        The initial messages to use, effectively defaults to all ones if not
+        specified.
+    output_inds : set[str], optional
+        The indices to consider as output (dangling) indices of the tn.
+        Computed automatically if not specified.
+    optimize : str or PathOptimizer, optional
+        The path optimizer to use when contracting the messages.
+    local_convergence : bool, optional
+        Whether to allow messages to locally converge - i.e. if all their
+        input messages have converged then stop updating them.
+    damping : float, optional
+        The damping parameter to use, defaults to no damping.
+    contract_opts
+        Other options supplied to ``cotengra.array_contract``.
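+
+    Examples
+    --------
+    A minimal usage sketch (assuming a small random PEPS)::
+
+        import quimb.tensor as qtn
+
+        peps = qtn.PEPS.rand(3, 3, 2, seed=42)
+        bp = D2BP(peps)
+        bp.run(max_iterations=100, tol=1e-6)
+        # BP estimate of the squared norm, <peps|peps>
+        norm_sq = bp.contract()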
+    """
+
+    def __init__(
+        self,
+        tn,
+        messages=None,
+        output_inds=None,
+        optimize="auto-hq",
+        local_convergence=True,
+        damping=0.0,
+        **contract_opts,
+    ):
+        from quimb.tensor.contraction import array_contract_expression
+
+        self.tn = tn
+        self.contract_opts = contract_opts
+        self.contract_opts.setdefault("optimize", optimize)
+        self.local_convergence = local_convergence
+        self.damping = damping
+
+        if output_inds is None:
+            self.output_inds = set(self.tn.outer_inds())
+        else:
+            self.output_inds = set(output_inds)
+
+        self.backend = next(t.backend for t in tn)
+        _abs = ar.get_lib_fn(self.backend, "abs")
+        _sum = ar.get_lib_fn(self.backend, "sum")
+
+        def _normalize(x):
+            return x / _sum(x)
+
+        def _distance(x, y):
+            return _sum(_abs(x - y))
+
+        self._normalize = _normalize
+        self._distance = _distance
+
+        if messages is None:
+            self.messages = {}
+        else:
+            self.messages = messages
+
+        # record which messages touch each other, for efficient updates
+        self.touch_map = {}
+        self.touched = set()
+        self.exprs = {}
+
+        # populate any messages
+        for ix, tids in self.tn.ind_map.items():
+            if ix in self.output_inds:
+                continue
+
+            tida, tidb = tids
+            jx = ix + "*"
+            ta, tb = self.tn._tids_get(tida, tidb)
+
+            for tid, t, t_in in ((tida, ta, tb), (tidb, tb, ta)):
+                this_touchmap = []
+                for nx in t.inds:
+                    if nx in self.output_inds or nx == ix:
+                        continue
+                    # where this message will be sent on to
+                    (tidn,) = (n for n in self.tn.ind_map[nx] if n != tid)
+                    this_touchmap.append((nx, tidn))
+                self.touch_map[ix, tid] = this_touchmap
+
+                if (ix, tid) not in self.messages:
+                    tm = (t_in.reindex({ix: jx}).conj_() @ t_in).data
+                    self.messages[ix, tid] = self._normalize(tm)
+
+        # for efficiency setup all the contraction expressions ahead of time
+        for ix, tids in self.tn.ind_map.items():
+            if ix in self.output_inds:
+                continue
+
+            for tida, tidb in (sorted(tids), sorted(tids, reverse=True)):
+                ta = self.tn.tensor_map[tida]
+                kix = ta.inds
+                bix = tuple(
+                    i if i in self.output_inds else i + "*" for i in kix
+                )
+                inputs = [kix, bix]
+                data = [ta.data, ta.data.conj()]
+                shapes = [ta.shape, ta.shape]
+                for i in kix:
+                    if (i != ix) and i not in self.output_inds:
+                        inputs.append((i + "*", i))
+                        data.append((i, tida))
+                        shapes.append(self.messages[i, tida].shape)
+
+                expr = array_contract_expression(
+                    inputs=inputs,
+                    output=(ix + "*", ix),
+                    shapes=shapes,
+                    **self.contract_opts,
+                )
+                self.exprs[ix, tidb] = expr, data
+
+    def update_touched_from_tids(self, *tids):
+        """Specify that the messages for the given ``tids`` have changed."""
+        for tid in tids:
+            t = self.tn.tensor_map[tid]
+            for ix in t.inds:
+                if ix in self.output_inds:
+                    continue
+                (ntid,) = (n for n in self.tn.ind_map[ix] if n != tid)
+                self.touched.add((ix, ntid))
+
+    def update_touched_from_tags(self, tags, which="any"):
+        """Specify that the messages touching any tensor tagged with ``tags``
+        have changed.
+        """
+        tids = self.tn._get_tids_from_tags(tags, which)
+        self.update_touched_from_tids(*tids)
+
+    def update_touched_from_inds(self, inds, which="any"):
+        """Specify that the messages touching any tensor with indices ``inds``
+        have changed.
+        """
+        tids = self.tn._get_tids_from_inds(inds, which)
+        self.update_touched_from_tids(*tids)
+
+    def iterate(self, tol=5e-6):
+        """Perform a single iteration of dense 2-norm belief propagation."""
+        if (not self.local_convergence) or (not self.touched):
+            # assume if asked to iterate that we want to check all messages
+            self.touched.update(self.exprs.keys())
+
+        ncheck = len(self.touched)
+        new_messages = {}
+        while self.touched:
+            key = self.touched.pop()
+            expr, data = self.exprs[key]
+            m = expr(*data[:2], *(self.messages[mkey] for mkey in data[2:]))
+            # enforce hermiticity and normalize
+            m = m + ar.dag(m)
+            m = self._normalize(m)
+
+            if self.damping > 0.0:
+                # mix the new message with the old one
+                m = self._normalize(
+                    (1 - self.damping) * m
+                    + self.damping * self.messages[key]
+                )
+
+            new_messages[key] = m
+
+        # process modified messages
+        nconv = 0
+        max_mdiff = -1.0
+        for key, m in new_messages.items():
+            mdiff = float(self._distance(m, self.messages[key]))
+
+            if mdiff > tol:
+                # mark touching messages for update
+                self.touched.update(self.touch_map[key])
+            else:
+                nconv += 1
+
+            max_mdiff = max(max_mdiff, mdiff)
+            self.messages[key] = m
+
+        return nconv, ncheck, max_mdiff
+
+    def compute_marginal(self, ind):
+        """Compute the marginal for the index ``ind``."""
+        (tid,) = self.tn.ind_map[ind]
+        t = self.tn.tensor_map[tid]
+
+        arrays = [t.data, ar.do("conj", t.data)]
+        k_input = []
+        b_input = []
+        m_inputs = []
+        for j, jx in enumerate(t.inds, 1):
+            k_input.append(j)
+
+            if jx == ind:
+                # output index -> take diagonal
+                output = (j,)
+                b_input.append(j)
+            else:
+                try:
+                    # partial trace with message
+                    m = self.messages[jx, tid]
+                    arrays.append(m)
+                    b_input.append(-j)
+                    m_inputs.append((-j, j))
+                except KeyError:
+                    # direct partial trace
+                    b_input.append(j)
+
+        p = qtn.array_contract(
+            arrays,
+            inputs=(tuple(k_input), tuple(b_input), *m_inputs),
+            output=output,
+            **self.contract_opts,
+        )
+        p = ar.do("real", p)
+        return p / ar.do("sum", p)
+
+    def contract(self, strip_exponent=False):
+        """Estimate the total contraction, i.e. the 2-norm.
+
+        Parameters
+        ----------
+        strip_exponent : bool, optional
+            Whether to strip the exponent from the final result. If ``True``
+            then the returned result is ``(mantissa, exponent)``.
+
+        Returns
+        -------
+        scalar or (scalar, float)
+        """
+        tvals = []
+
+        for tid, t in self.tn.tensor_map.items():
+            arrays = [t.data, ar.do("conj", t.data)]
+            k_input = []
+            b_input = []
+            m_inputs = []
+            for i, ix in enumerate(t.inds, 1):
+                k_input.append(i)
+                if ix in self.output_inds:
+                    b_input.append(i)
+                else:
+                    b_input.append(-i)
+                    m_inputs.append((-i, i))
+                    arrays.append(self.messages[ix, tid])
+
+            inputs = (tuple(k_input), tuple(b_input), *m_inputs)
+            output = ()
+            tval = qtn.array_contract(
+                arrays, inputs, output, **self.contract_opts
+            )
+            tvals.append(tval)
+
+        mvals = []
+        for ix, tids in self.tn.ind_map.items():
+            if ix in self.output_inds:
+                continue
+            tida, tidb = tids
+            ml = self.messages[ix, tidb]
+            mr = self.messages[ix, tida]
+            mval = qtn.array_contract(
+                (ml, mr), ((1, 2), (1, 2)), (), **self.contract_opts
+            )
+            mvals.append(mval)
+
+        return combine_local_contractions(
+            tvals, mvals, self.backend, strip_exponent=strip_exponent
+        )
+
+    def compress(
+        self,
+        max_bond,
+        cutoff=0.0,
+        cutoff_mode=4,
+        renorm=0,
+        inplace=False,
+    ):
+        """Compress the initial tensor network using the current messages."""
+        tn = self.tn if inplace else self.tn.copy()
+
+        for ix, tids in tn.ind_map.items():
+            if len(tids) != 2:
+                continue
+            tida, tidb = tids
+
+            # messages are left and right factors squared already
+            ta = tn.tensor_map[tida]
+            dm = ta.ind_size(ix)
+            dl = ta.size // dm
+            ml = self.messages[ix, tidb]
+            Rl = qtn.decomp.squared_op_to_reduced_factor(
+                ml, dl, dm, right=True
+            )
+
+            tb = tn.tensor_map[tidb]
+            dr = tb.size // dm
+            mr = self.messages[ix, tida].T
+            Rr = qtn.decomp.squared_op_to_reduced_factor(
+                mr, dm, dr, right=False
+            )
+
+            # compute the compressors
+            Pl, Pr = qtn.decomp.compute_oblique_projectors(
+                Rl,
+                Rr,
+                max_bond=max_bond,
+                cutoff=cutoff,
+                cutoff_mode=cutoff_mode,
+                renorm=renorm,
+            )
+
+            # contract the compressors into the tensors
+            tn.tensor_map[tida].gate_(Pl.T, ix)
+            tn.tensor_map[tidb].gate_(Pr, ix)
+
+            # update messages with projections
+            if inplace:
+                new_Ra = Rl @ Pl
+                new_Rb = Pr @ Rr
+                self.messages[ix, tidb] = ar.dag(new_Ra) @ new_Ra
+                self.messages[ix, tida] = new_Rb @ ar.dag(new_Rb)
+
+        return tn
+
+
+def contract_d2bp(
+    tn,
+    messages=None,
+    output_inds=None,
+    optimize="auto-hq",
+    local_convergence=True,
+    damping=0.0,
+    max_iterations=1000,
+    tol=5e-6,
+    strip_exponent=False,
+    progbar=False,
+    **contract_opts,
+):
+    """Estimate the norm squared of ``tn`` using dense 2-norm belief
+    propagation.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to form the 2-norm of and run BP on.
+    messages : dict[(str, int), array_like], optional
+        The initial messages to use, effectively defaults to all ones if not
+        specified.
+    max_iterations : int, optional
+        The maximum number of iterations to perform.
+    tol : float, optional
+        The convergence tolerance for messages.
+    output_inds : set[str], optional
+        The indices to consider as output (dangling) indices of the tn.
+        Computed automatically if not specified.
+    optimize : str or PathOptimizer, optional
+        The path optimizer to use when contracting the messages.
+    local_convergence : bool, optional
+        Whether to allow messages to locally converge - i.e. if all their
+        input messages have converged then stop updating them.
+    damping : float, optional
+        The damping parameter to use, defaults to no damping.
+    strip_exponent : bool, optional
+        Whether to strip the exponent from the final result. If ``True``
+        then the returned result is ``(mantissa, exponent)``.
+    progbar : bool, optional
+        Whether to show a progress bar.
+    contract_opts
+        Other options supplied to ``cotengra.array_contract``.
+
+    Returns
+    -------
+    scalar or (scalar, float)
+    """
+    bp = D2BP(
+        tn,
+        messages=messages,
+        output_inds=output_inds,
+        optimize=optimize,
+        local_convergence=local_convergence,
+        damping=damping,
+        **contract_opts,
+    )
+    bp.run(
+        max_iterations=max_iterations,
+        tol=tol,
+        progbar=progbar,
+    )
+    return bp.contract(strip_exponent=strip_exponent)
+
+
+def compress_d2bp(
+    tn,
+    max_bond,
+    cutoff=0.0,
+    cutoff_mode="rsum2",
+    renorm=0,
+    messages=None,
+    output_inds=None,
+    optimize="auto-hq",
+    local_convergence=True,
+    damping=0.0,
+    max_iterations=1000,
+    tol=5e-6,
+    inplace=False,
+    progbar=False,
+    **contract_opts,
+):
+    """Compress the tensor network ``tn`` using dense 2-norm belief
+    propagation.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to form the 2-norm of, run BP on and then compress.
+    max_bond : int
+        The maximum bond dimension to compress to.
+    cutoff : float, optional
+        The cutoff to use when compressing.
+    cutoff_mode : str or int, optional
+        The cutoff mode to use when compressing.
+    messages : dict[(str, int), array_like], optional
+        The initial messages to use, effectively defaults to all ones if not
+        specified.
+    max_iterations : int, optional
+        The maximum number of iterations to perform.
+    tol : float, optional
+        The convergence tolerance for messages.
+    output_inds : set[str], optional
+        The indices to consider as output (dangling) indices of the tn.
+        Computed automatically if not specified.
+    optimize : str or PathOptimizer, optional
+        The path optimizer to use when contracting the messages.
+    local_convergence : bool, optional
+        Whether to allow messages to locally converge - i.e. if all their
+        input messages have converged then stop updating them.
+    damping : float, optional
+        The damping parameter to use, defaults to no damping.
+    inplace : bool, optional
+        Whether to perform the compression inplace.
+    progbar : bool, optional
+        Whether to show a progress bar.
+    contract_opts
+        Other options supplied to ``cotengra.array_contract``.
+
+    Returns
+    -------
+    TensorNetwork
+    """
+    bp = D2BP(
+        tn,
+        messages=messages,
+        output_inds=output_inds,
+        optimize=optimize,
+        local_convergence=local_convergence,
+        damping=damping,
+        **contract_opts,
+    )
+    bp.run(
+        max_iterations=max_iterations,
+        tol=tol,
+        progbar=progbar,
+    )
+    return bp.compress(
+        max_bond=max_bond,
+        cutoff=cutoff,
+        cutoff_mode=cutoff_mode,
+        renorm=renorm,
+        inplace=inplace,
+    )
+
+
+def sample_d2bp(
+    tn,
+    output_inds=None,
+    messages=None,
+    max_iterations=100,
+    tol=1e-2,
+    bias=None,
+    seed=None,
+    local_convergence=True,
+    progbar=False,
+    **contract_opts,
+):
+    """Sample a configuration from ``tn`` using dense 2-norm belief
+    propagation.
+
+    Parameters
+    ----------
+    tn : TensorNetwork
+        The tensor network to sample from.
+    output_inds : set[str], optional
+        Which indices to sample.
+    messages : dict[(str, int), array_like], optional
+        The initial messages to use, effectively defaults to all ones if not
+        specified.
+    max_iterations : int, optional
+        The maximum number of iterations to perform, per marginal.
+    tol : float, optional
+        The convergence tolerance for messages.
+    bias : float, optional
+        Bias the sampling towards more locally likely bit-strings. This is
+        done by raising the probability of each bit-string to this power.
+    seed : int, optional
+        A random seed for reproducibility.
+    local_convergence : bool, optional
+        Whether to allow messages to locally converge - i.e. if all their
+        input messages have converged then stop updating them.
+    progbar : bool, optional
+        Whether to show a progress bar.
+    contract_opts
+        Other options supplied to ``cotengra.array_contract``.
+
+    Returns
+    -------
+    config : dict[str, int]
+        The sampled configuration, a mapping of output indices to values.
+    tn_config : TensorNetwork
+        The tensor network with the sampled configuration applied.
+    omega : float
+        The BP probability of the sampled configuration.
+    """
+    import numpy as np
+
+    if output_inds is None:
+        output_inds = tn.outer_inds()
+
+    rng = np.random.default_rng(seed)
+    config = {}
+    omega = 1.0
+
+    tn = tn.copy()
+    bp = D2BP(
+        tn,
+        messages=messages,
+        local_convergence=local_convergence,
+        **contract_opts,
+    )
+    bp.run(max_iterations=max_iterations, tol=tol)
+
+    marginals = dict.fromkeys(output_inds)
+
+    if progbar:
+        import tqdm
+
+        pbar = tqdm.tqdm(total=len(marginals))
+    else:
+        pbar = None
+
+    while marginals:
+        for ix in marginals:
+            marginals[ix] = bp.compute_marginal(ix)
+
+        # sample the most locally biased marginal first
+        ix, p = max(marginals.items(), key=lambda x: max(x[1]))
+        p = ar.to_numpy(p)
+
+        if bias is not None:
+            # bias distribution towards more locally likely bit-strings
+            p = p**bias
+            p /= np.sum(p)
+
+        v = rng.choice([0, 1], p=p)
+        config[ix] = v
+        del marginals[ix]
+
+        tids = tuple(tn.ind_map[ix])
+        tn.isel_({ix: v})
+
+        omega *= p[v]
+        if progbar:
+            pbar.update(1)
+            pbar.set_description(f"{ix}->{v}", refresh=False)
+
+        bp = D2BP(
+            tn,
+            messages=bp.messages,
+            local_convergence=local_convergence,
+            **contract_opts,
+        )
+        bp.update_touched_from_tids(*tids)
+        bp.run(tol=tol, max_iterations=max_iterations)
+
+    if progbar:
+        pbar.close()
+
+    return config, tn, omega
diff --git a/tests/test_tensor/test_belief_propagation/__init__.py b/tests/test_tensor/test_belief_propagation/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_tensor/test_belief_propagation/test_d2bp.py b/tests/test_tensor/test_belief_propagation/test_d2bp.py
new file mode 100644
index 00000000..0bdd9b06
--- /dev/null
+++ b/tests/test_tensor/test_belief_propagation/test_d2bp.py
@@ -0,0 +1,48 @@
+import pytest
+
+import quimb.tensor as qtn
+from quimb.experimental.belief_propagation.d2bp import (
+    contract_d2bp,
+    compress_d2bp,
+    sample_d2bp,
+)
+
+
+@pytest.mark.parametrize("damping", [0.0, 0.1, 0.5])
+def test_contract(damping):
+    peps = qtn.PEPS.rand(3, 4, 3, seed=42)
+    # normalize exactly
+    peps /= (peps.H @ peps) ** 0.5
+    N_ap = contract_d2bp(peps, damping=damping)
+    assert N_ap == pytest.approx(1.0, rel=0.3)
+
+
+@pytest.mark.parametrize("damping", [0.0, 0.1, 0.5])
+def test_compress(damping):
+    peps = qtn.PEPS.rand(3, 4, 3, seed=42)
+    # test that using the BP compression gives better fidelity than a purely
+    # local, naive compression scheme
+    peps_c1 = peps.compress_all(max_bond=2)
+    peps_c2 = compress_d2bp(peps, max_bond=2, damping=damping)
+    fid1 = peps_c1.H @ peps_c2
+    fid2 = peps_c2.H @ peps_c2
+    assert fid2 > fid1
+
+
+def test_sample():
+    peps = qtn.PEPS.rand(3, 4, 3, seed=42)
+    # normalize exactly
+    peps /= (peps.H @ peps) ** 0.5
+    config, peps_config, omega = sample_d2bp(peps, seed=42)
+    assert all(ix in config for ix in peps.site_inds)
+    assert 0.0 < omega < 1.0
+    assert peps_config.outer_inds() == ()
+
+    ptotal = 0.0
+    nrepeat = 4
+    for _ in range(nrepeat):
+        _, peps_config, _ = sample_d2bp(peps, seed=42)
+        ptotal += peps_config.contract() ** 2
+
+    # check we are doing better than random guessing
+    assert ptotal > nrepeat * 2**-peps.nsites
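
As a usage sketch, the three convenience functions wrap construction and
running of a ``D2BP`` instance in single calls (mirroring the tests above,
with a small random PEPS)::

    import quimb.tensor as qtn
    from quimb.experimental.belief_propagation.d2bp import (
        contract_d2bp,
        compress_d2bp,
        sample_d2bp,
    )

    peps = qtn.PEPS.rand(3, 4, 3, seed=42)
    # normalize exactly so the 2-norm estimate should be ~1.0
    peps /= (peps.H @ peps) ** 0.5

    # estimate the squared norm <peps|peps> via BP
    norm_sq = contract_d2bp(peps)

    # compress all bonds down to dimension 2 using converged BP messages
    peps_c = compress_d2bp(peps, max_bond=2)

    # sample a configuration of the physical indices
    config, peps_config, omega = sample_d2bp(peps, seed=42)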