Showing 5 changed files with 1,007 additions and 0 deletions.
@@ -0,0 +1,91 @@
"""Belief propagation (BP) routines. There are three independent ways to
categorize BP, and each combination of them is a potentially valid specific
algorithm.

1-norm vs 2-norm BP
-------------------

- 1-norm (normal): BP runs directly on the tensor network, messages have size
  ``d`` where ``d`` is the size of the bond(s) connecting two tensors or
  regions.
- 2-norm (quantum): BP runs on the squared tensor network, messages have size
  ``d^2`` where ``d`` is the size of the bond(s) connecting two tensors or
  regions. Each local tensor or region is partially traced (over dangling
  indices) with its conjugate to create a single node.

Graph vs Hypergraph BP
----------------------

- Graph (simple): the tensor network lives on a graph, where indices either
  appear on two tensors (a bond), or appear on a single tensor (are outputs).
  In this case, messages are exchanged directly between tensors.
- Hypergraph: the tensor network lives on a hypergraph, where indices can
  appear on any number of tensors. In this case, the update procedure has two
  parts: first all 'tensor' messages are computed, and these are then used in
  the second step to compute all the 'index' messages, which are fed back into
  the 'tensor' message update, and so forth. For 2-norm BP one likely needs to
  specify which indices are outputs and should be traced over.

The hypergraph case of course includes the graph case, but since the 'index'
message update is then simply the identity, it is convenient to have a
separate, simpler implementation, where the standard TN bond vs physical index
definitions hold.

Dense vs Vectorized vs Lazy BP
------------------------------

- Dense: each node is a single tensor, or pair of tensors for 2-norm BP. If
  all multibonds have been fused, then each message is a vector (1-norm case)
  or matrix (2-norm case).
- Vectorized: the same as the above, but all matching tensor and message
  updates are stacked and performed simultaneously. This can be enormously
  more efficient for large numbers of small tensors.
- Lazy: each node is potentially a tensor network itself with arbitrary inner
  structure and number of bonds connecting to other nodes. The messages are
  generally tensors and each update is a lazy contraction, which is
  potentially much cheaper / requires less memory than forming the 'dense'
  node for large tensors.

(There is also the MPS flavor where each node has a 1D structure and the
messages are matrix product states, with updates involving compression.)

Overall that gives 12 possible BP flavors, some implemented here:

- [x] (HD1BP) hyper, dense, 1-norm - this is the standard BP algorithm
- [x] (HD2BP) hyper, dense, 2-norm
- [x] (HV1BP) hyper, vectorized, 1-norm
- [ ] (HV2BP) hyper, vectorized, 2-norm
- [ ] (HL1BP) hyper, lazy, 1-norm
- [ ] (HL2BP) hyper, lazy, 2-norm
- [ ] (D1BP) simple, dense, 1-norm
- [x] (D2BP) simple, dense, 2-norm - this is the standard PEPS BP algorithm
- [ ] (V1BP) simple, vectorized, 1-norm
- [ ] (V2BP) simple, vectorized, 2-norm
- [x] (L1BP) simple, lazy, 1-norm
- [x] (L2BP) simple, lazy, 2-norm

The 2-norm methods can be used to compress bonds or estimate the 2-norm.
The 1-norm methods can be used to estimate the 1-norm, i.e. the contracted
value. Both can be used to compute index marginals and thus perform sampling.

The vectorized methods can be extremely fast for large numbers of small
tensors, but do currently require all dimensions to match.

The dense and lazy methods can converge messages *locally*, i.e. only
update messages adjacent to messages which have changed.
"""

from .bp_common import initialize_hyper_messages
from .d2bp import D2BP, contract_d2bp, compress_d2bp, sample_d2bp
from .hd1bp import HD1BP
from .hv1bp import HV1BP

__all__ = (
    "initialize_hyper_messages",
    "D2BP",
    "contract_d2bp",
    "compress_d2bp",
    "sample_d2bp",
    "HD1BP",
    "HV1BP",
)
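
As a usage sketch (hedged: the package import path and the exact keyword
arguments are assumptions, only the exported names above are taken from this
commit), the dense 2-norm routines might be driven like so:

    import quimb.tensor as qtn
    from quimb.experimental.belief_propagation import (
        contract_d2bp,
        compress_d2bp,
    )

    # a small random PEPS whose squared norm <psi|psi> we want to estimate
    psi = qtn.PEPS.rand(4, 4, bond_dim=3, seed=42)

    # estimate <psi|psi> with simple dense 2-norm BP
    norm2 = contract_d2bp(psi)

    # or use converged BP messages to compress every bond down to 2
    psi_c = compress_d2bp(psi, max_bond=2)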
@@ -0,0 +1,264 @@
import functools
import operator

import autoray as ar

import quimb.tensor as qtn


class RollingDiffMean:
    """Tracker for the absolute rolling mean of diffs between values, to
    assess effective convergence of BP above actual message tolerance.
    """

    def __init__(self, size=16):
        self.size = size
        self.diffs = []
        self.last_x = None
        self.dxsum = 0.0

    def update(self, x):
        if self.last_x is not None:
            dx = x - self.last_x
            self.diffs.append(dx)
            self.dxsum += dx / self.size
        if len(self.diffs) > self.size:
            dx = self.diffs.pop(0)
            self.dxsum -= dx / self.size
        self.last_x = x

    def absmeandiff(self):
        if len(self.diffs) < self.size:
            return float("inf")
        return abs(self.dxsum)

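# A minimal usage sketch (illustrative values only): feed in a scalar
# diagnostic such as the max message change per iteration, and treat a small
# rolling mean of its differences as a plateau, i.e. effective convergence:
#
#     rdm = RollingDiffMean(size=4)
#     for x in (0.5, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3):
#         rdm.update(x)
#     converged = rdm.absmeandiff() < 1e-6  # -> True
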
class BeliefPropagationCommon:
    """Common interfaces for belief propagation algorithms.

    Parameters
    ----------
    max_iterations : int, optional
        The maximum number of iterations to perform.
    tol : float, optional
        The convergence tolerance for messages.
    progbar : bool, optional
        Whether to show a progress bar.
    """

    def run(self, max_iterations=1000, tol=5e-6, progbar=False):
        if progbar:
            import tqdm

            pbar = tqdm.tqdm()
        else:
            pbar = None

        try:
            it = 0
            rdm = RollingDiffMean()
            self.converged = False
            while not self.converged and it < max_iterations:
                # can only converge if tol > 0.0
                nconv, ncheck, max_mdiff = self.iterate(tol=tol)
                rdm.update(max_mdiff)
                self.converged = (max_mdiff < tol) or (rdm.absmeandiff() < tol)
                it += 1

                if pbar is not None:
                    pbar.set_description(
                        f"nconv: {nconv}/{ncheck} max|dM|={max_mdiff:.2e}",
                        refresh=False,
                    )
                    pbar.update()

        finally:
            if pbar is not None:
                pbar.close()

        if tol != 0.0 and not self.converged:
            import warnings

            warnings.warn(
                f"Belief propagation did not converge after {max_iterations} "
                f"iterations, tol={tol:.2e}, max|dM|={max_mdiff:.2e}."
            )

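# Concrete BP classes are expected to provide an ``iterate`` method; ``run``
# above unpacks its result as ``(nconv, ncheck, max_mdiff)``: the number of
# messages already converged, the number checked, and the largest message
# change. A hedged sketch of that contract:
#
#     class MyBP(BeliefPropagationCommon):
#         def iterate(self, tol=5e-6):
#             ...  # update all messages once, tracking changes
#             return nconv, ncheck, max_mdiff
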
def prod(xs):
    """Product of all elements in ``xs``."""
    return functools.reduce(operator.mul, xs)


def initialize_hyper_messages(
    tn,
    fill_fn=None,
    smudge_factor=1e-12,
):
    """Initialize messages for belief propagation; this is equivalent to doing
    a single round of belief propagation with uniform messages.

    Parameters
    ----------
    tn : TensorNetwork
        The tensor network to initialize messages for.
    fill_fn : callable, optional
        A function to fill the messages with, of signature ``fill_fn(shape)``.
    smudge_factor : float, optional
        A small number to add to the messages to avoid numerical issues.

    Returns
    -------
    messages : dict
        The initial messages. For every index and tensor id pair, there will
        be a message to and from with keys ``(ix, tid)`` and ``(tid, ix)``.
    """
    from quimb.tensor.contraction import array_contract

    backend = ar.infer_backend(next(t.data for t in tn))
    _sum = ar.get_lib_fn(backend, "sum")

    messages = {}

    # compute first messages from tensors to indices
    for tid, t in tn.tensor_map.items():
        k_inputs = tuple(range(t.ndim))
        for i, ix in enumerate(t.inds):
            if fill_fn is None:
                # sum over all other indices to get initial message
                m = array_contract(
                    arrays=(t.data,),
                    inputs=(k_inputs,),
                    output=(i,),
                )
                # normalize and insert
                messages[tid, ix] = m / _sum(m)
            else:
                d = t.ind_size(ix)
                m = fill_fn((d,))
                messages[tid, ix] = m / _sum(m)

    # compute first messages from indices to tensors
    for ix, tids in tn.ind_map.items():
        ms = [messages[tid, ix] for tid in tids]
        mp = prod(ms)
        for mi, tid in zip(ms, tids):
            m = mp / (mi + smudge_factor)
            # normalize and insert
            messages[ix, tid] = m / _sum(m)

    return messages

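# A hedged usage sketch (``TN2D_rand`` is just a convenient example network):
#
#     tn = qtn.TN2D_rand(3, 3, D=2)
#     messages = initialize_hyper_messages(tn)
#     tid = next(iter(tn.tensor_map))
#     ix = tn.tensor_map[tid].inds[0]
#     messages[tid, ix]  # tensor -> index message, a vector of size d
#     messages[ix, tid]  # index -> tensor message, a vector of size d
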
def combine_local_contractions(
    tvals,
    mvals,
    backend,
    strip_exponent=False,
    check_for_zero=True,
):
    """Combine local tensor contraction values ``tvals`` and message
    contraction values ``mvals`` into the overall BP estimate
    ``prod(tvals) / prod(mvals)``, accumulated in separate mantissa and
    (base 10) exponent form to avoid overflow or underflow.
    """
    _abs = ar.get_lib_fn(backend, "abs")
    _log10 = ar.get_lib_fn(backend, "log10")

    mantissa = 1
    exponent = 0
    for vt in tvals:
        avt = _abs(vt)

        if check_for_zero and (avt == 0.0):
            if strip_exponent:
                return 0.0, 0.0
            else:
                return 0.0

        # accumulate sign / phase in the mantissa, magnitude in the exponent
        mantissa = mantissa * (vt / avt)
        exponent = exponent + _log10(avt)
    for mt in mvals:
        amt = _abs(mt)
        mantissa = mantissa / (mt / amt)
        exponent = exponent - _log10(amt)

    if strip_exponent:
        return mantissa, exponent
    else:
        return mantissa * 10**exponent

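# For example (a hedged numpy sketch): with ``tvals = [2.0, 3.0]`` and
# ``mvals = [0.5]`` the result is 2 * 3 / 0.5 = 12.0, internally held as
# mantissa 1.0 and exponent log10(12), so very large or very small local
# contraction values never overflow:
#
#     import numpy as np
#     v = combine_local_contractions(
#         [np.float64(2.0), np.float64(3.0)],
#         [np.float64(0.5)],
#         backend="numpy",
#     )  # -> 12.0
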
def maybe_get_thread_pool(thread_pool):
    """Get a thread pool if requested."""
    if thread_pool is False:
        return None

    if thread_pool is True:
        import quimb as qu

        return qu.get_thread_pool()

    if isinstance(thread_pool, int):
        import quimb as qu

        return qu.get_thread_pool(thread_pool)

    return thread_pool

def create_lazy_community_edge_map(tn, site_tags=None, rank_simplify=True):
    """For lazy BP algorithms, create the data structures describing the
    effective graph of the lazily grouped 'sites' given by ``site_tags``.
    """
    if site_tags is None:
        site_tags = set(tn.site_tags)
    else:
        site_tags = set(site_tags)

    edges = {}
    neighbors = {}
    local_tns = {}
    touch_map = {}

    for ix in tn.ind_map:
        ts = tn._inds_get(ix)
        tags = {tag for t in ts for tag in t.tags if tag in site_tags}
        if len(tags) >= 2:
            i, j = tuple(sorted(tags))

            if (i, j) in edges:
                # already processed this edge
                continue

            # add to neighbor map
            neighbors.setdefault(i, []).append(j)
            neighbors.setdefault(j, []).append(i)

            # get local TNs and compute bonds between them,
            # rank simplify here also to prepare for contractions
            try:
                tn_i = local_tns[i]
            except KeyError:
                tn_i = local_tns[i] = tn.select(i, virtual=False)
                if rank_simplify:
                    tn_i.rank_simplify_()
            try:
                tn_j = local_tns[j]
            except KeyError:
                tn_j = local_tns[j] = tn.select(j, virtual=False)
                if rank_simplify:
                    tn_j.rank_simplify_()

            edges[i, j] = tuple(qtn.bonds(tn_i, tn_j))

    for i, j in edges:
        touch_map[(i, j)] = tuple((j, k) for k in neighbors[j] if k != i)
        touch_map[(j, i)] = tuple((i, k) for k in neighbors[i] if k != j)

    if len(local_tns) != len(site_tags):
        # handle potentially disconnected sites
        for i in sorted(site_tags):
            try:
                tn_i = local_tns[i] = tn.select(i, virtual=False)
                if rank_simplify:
                    tn_i.rank_simplify_()
            except KeyError:
                pass

    return edges, neighbors, local_tns, touch_map
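
As a final usage sketch (hedged: ``TN2D_rand`` and its default site tags are
just a convenient example), the four returned structures describe the
coarse-grained graph that the lazy BP algorithms iterate over:

    tn = qtn.TN2D_rand(2, 2, D=2)  # site tags 'I0,0', 'I0,1', ...
    edges, neighbors, local_tns, touch_map = create_lazy_community_edge_map(tn)
    # edges[i, j]     -> bond indices connecting site groups i and j
    # neighbors[i]    -> site groups adjacent to group i
    # local_tns[i]    -> the local (rank-simplified) tensor network for group i
    # touch_map[i, j] -> messages to recompute after message i -> j changes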