-
Notifications
You must be signed in to change notification settings - Fork 112
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
358 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,259 @@ | ||
import autoray as ar | ||
|
||
import quimb.tensor as qtn | ||
|
||
from .bp_common import ( | ||
BeliefPropagationCommon, | ||
create_lazy_community_edge_map, | ||
combine_local_contractions, | ||
) | ||
|
||
|
||
class L1BP(BeliefPropagationCommon): | ||
"""Lazy 1-norm belief propagation. BP is run between groups of tensors | ||
defined by ``site_tags``. The message updates are lazy contractions. | ||
Parameters | ||
---------- | ||
tn : TensorNetwork | ||
The tensor network to run BP on. | ||
site_tags : sequence of str, optional | ||
The tags identifying the sites in ``tn``, each tag forms a region, | ||
which should not overlap. If the tensor network is structured, then | ||
these are inferred automatically. | ||
damping : float, optional | ||
The damping parameter to use, defaults to no damping. | ||
local_convergence : bool, optional | ||
Whether to allow messages to locally converge - i.e. if all their | ||
input messages have converged then stop updating them. | ||
optimize : str or PathOptimizer, optional | ||
The path optimizer to use when contracting the messages. | ||
contract_opts | ||
Other options supplied to ``cotengra.array_contract``. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
tn, | ||
site_tags=None, | ||
damping=0.0, | ||
local_convergence=True, | ||
optimize="auto-hq", | ||
message_init_function=None, | ||
**contract_opts, | ||
): | ||
self.backend = next(t.backend for t in tn) | ||
self.damping = damping | ||
self.local_convergence = local_convergence | ||
self.optimize = optimize | ||
self.contract_opts = contract_opts | ||
|
||
if site_tags is None: | ||
self.site_tags = tuple(tn.site_tags) | ||
else: | ||
self.site_tags = tuple(site_tags) | ||
|
||
( | ||
self.edges, | ||
self.neighbors, | ||
self.local_tns, | ||
self.touch_map, | ||
) = create_lazy_community_edge_map(tn, site_tags) | ||
self.touched = set() | ||
|
||
self._abs = ar.get_lib_fn(self.backend, "abs") | ||
self._max = ar.get_lib_fn(self.backend, "max") | ||
self._sum = ar.get_lib_fn(self.backend, "sum") | ||
_real = ar.get_lib_fn(self.backend, "real") | ||
_argmax = ar.get_lib_fn(self.backend, "argmax") | ||
_reshape = ar.get_lib_fn(self.backend, "reshape") | ||
self._norm = ar.get_lib_fn(self.backend, "linalg.norm") | ||
|
||
def _normalize(x): | ||
return x / self._sum(x) | ||
# return x / self._norm(x) | ||
# return x / self._max(x) | ||
# fx = _reshape(x, (-1,)) | ||
# return x / fx[_argmax(self._abs(_real(fx)))] | ||
|
||
def _distance(x, y): | ||
return self._sum(self._abs(x - y)) | ||
|
||
self._normalize = _normalize | ||
self._distance = _distance | ||
|
||
# for each meta bond create initial messages | ||
self.messages = {} | ||
for pair, bix in self.edges.items(): | ||
# compute leftwards and rightwards messages | ||
for i, j in (sorted(pair), sorted(pair, reverse=True)): | ||
tn_i = self.local_tns[i] | ||
# initial message just sums over dangling bonds | ||
|
||
if message_init_function is None: | ||
tm = tn_i.contract( | ||
output_inds=bix, | ||
optimize=self.optimize, | ||
drop_tags=True, | ||
**self.contract_opts, | ||
) | ||
# normalize | ||
tm.modify(apply=self._normalize) | ||
else: | ||
shape = tuple(tn_i.ind_size(ix) for ix in bix) | ||
tm = qtn.Tensor( | ||
data=message_init_function(shape), | ||
inds=bix, | ||
) | ||
|
||
self.messages[i, j] = tm | ||
|
||
# compute the contractions | ||
self.contraction_tns = {} | ||
for pair, bix in self.edges.items(): | ||
# for each meta bond compute left and right contractions | ||
for i, j in (sorted(pair), sorted(pair, reverse=True)): | ||
tn_i = self.local_tns[i].copy() | ||
# attach incoming messages to dangling bonds | ||
tks = [ | ||
self.messages[k, i] for k in self.neighbors[i] if k != j | ||
] | ||
# virtual so we can modify messages tensors inplace | ||
tn_i_to_j = qtn.TensorNetwork((tn_i, *tks), virtual=True) | ||
self.contraction_tns[i, j] = tn_i_to_j | ||
|
||
def iterate(self, tol=5e-6): | ||
if (not self.local_convergence) or (not self.touched): | ||
# assume if asked to iterate that we want to check all messages | ||
self.touched.update( | ||
pair for edge in self.edges for pair in (edge, edge[::-1]) | ||
) | ||
|
||
ncheck = len(self.touched) | ||
new_data = {} | ||
while self.touched: | ||
i, j = self.touched.pop() | ||
|
||
bix = self.edges[(i, j) if i < j else (j, i)] | ||
tn_i_to_j = self.contraction_tns[i, j] | ||
tm_new = tn_i_to_j.contract( | ||
output_inds=bix, | ||
optimize=self.optimize, | ||
**self.contract_opts, | ||
) | ||
m = self._normalize(tm_new.data) | ||
|
||
new_data[i, j] = m | ||
|
||
nconv = 0 | ||
max_mdiff = -1.0 | ||
for key, data in new_data.items(): | ||
tm = self.messages[key] | ||
|
||
if self.damping != 0.0: | ||
data = (1 - self.damping) * data + self.damping * tm.data | ||
|
||
mdiff = float(self._distance(tm.data, data)) | ||
|
||
if mdiff > tol: | ||
# mark touching messages for update | ||
self.touched.update(self.touch_map[key]) | ||
else: | ||
nconv += 1 | ||
|
||
max_mdiff = max(max_mdiff, mdiff) | ||
tm.modify(data=data) | ||
|
||
return nconv, ncheck, max_mdiff | ||
|
||
def contract(self, strip_exponent=False): | ||
tvals = [] | ||
for site, tn_ic in self.local_tns.items(): | ||
if site in self.neighbors: | ||
tval = qtn.tensor_contract( | ||
*tn_ic, | ||
*(self.messages[k, site] for k in self.neighbors[site]), | ||
optimize=self.optimize, | ||
**self.contract_opts, | ||
) | ||
else: | ||
# site exists but has no neighbors | ||
tval = tn_ic.contract( | ||
output_inds=(), | ||
optimize=self.optimize, | ||
**self.contract_opts, | ||
) | ||
tvals.append(tval) | ||
|
||
mvals = [] | ||
for i, j in self.edges: | ||
mval = qtn.tensor_contract( | ||
self.messages[i, j], | ||
self.messages[j, i], | ||
optimize=self.optimize, | ||
**self.contract_opts, | ||
) | ||
mvals.append(mval) | ||
|
||
return combine_local_contractions( | ||
tvals, mvals, self.backend, strip_exponent=strip_exponent | ||
) | ||
|
||
|
||
def contract_l1bp( | ||
tn, | ||
max_iterations=1000, | ||
tol=5e-6, | ||
site_tags=None, | ||
damping=0.0, | ||
local_convergence=True, | ||
optimize="auto-hq", | ||
strip_exponent=False, | ||
progbar=False, | ||
**contract_opts, | ||
): | ||
"""Estimate the contraction of ``tn`` using lazy 1-norm belief propagation. | ||
Parameters | ||
---------- | ||
tn : TensorNetwork | ||
The tensor network to contract. | ||
max_iterations : int, optional | ||
The maximum number of iterations to perform. | ||
tol : float, optional | ||
The convergence tolerance for messages. | ||
site_tags : sequence of str, optional | ||
The tags identifying the sites in ``tn``, each tag forms a region. If | ||
the tensor network is structured, then these are inferred | ||
automatically. | ||
damping : float, optional | ||
The damping parameter to use, defaults to no damping. | ||
local_convergence : bool, optional | ||
Whether to allow messages to locally converge - i.e. if all their | ||
input messages have converged then stop updating them. | ||
optimize : str or PathOptimizer, optional | ||
The path optimizer to use when contracting the messages. | ||
progbar : bool, optional | ||
Whether to show a progress bar. | ||
strip_exponent : bool, optional | ||
Whether to strip the exponent from the final result. If ``True`` | ||
then the returned result is ``(mantissa, exponent)``. | ||
contract_opts | ||
Other options supplied to ``cotengra.array_contract``. | ||
""" | ||
bp = L1BP( | ||
tn, | ||
site_tags=site_tags, | ||
damping=damping, | ||
local_convergence=local_convergence, | ||
optimize=optimize, | ||
**contract_opts, | ||
) | ||
bp.run( | ||
max_iterations=max_iterations, | ||
tol=tol, | ||
progbar=progbar, | ||
) | ||
return bp.contract( | ||
strip_exponent=strip_exponent, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import pytest | ||
|
||
import quimb as qu | ||
import quimb.tensor as qtn | ||
from quimb.experimental.belief_propagation.l1bp import contract_l1bp | ||
from quimb.experimental.belief_propagation.d2bp import contract_d2bp | ||
|
||
|
||
@pytest.mark.parametrize("dtype", ["float32", "complex64"]) | ||
def test_contract_tree_exact(dtype): | ||
tn = qtn.TN_rand_tree(10, 3, seed=42, dtype=dtype) | ||
Z_ex = tn.contract() | ||
Z_bp = contract_l1bp(tn) | ||
assert Z_ex == pytest.approx(Z_bp, rel=5e-6) | ||
|
||
|
||
@pytest.mark.parametrize("dtype", ["float32", "complex64"]) | ||
@pytest.mark.parametrize("damping", [0.0, 0.1]) | ||
def test_contract_loopy_approx(dtype, damping): | ||
tn = qtn.TN2D_rand(3, 4, 5, dtype=dtype, dist="uniform") | ||
Z_ex = tn.contract() | ||
Z_bp = contract_l1bp(tn, damping=damping) | ||
assert Z_ex == pytest.approx(Z_bp, rel=0.1) | ||
|
||
|
||
@pytest.mark.parametrize("dtype", ["float32", "complex64"]) | ||
@pytest.mark.parametrize("damping", [0.0, 0.1]) | ||
def test_contract_double_loopy_approx(dtype, damping): | ||
peps = qtn.PEPS.rand(4, 3, 2, seed=42, dtype=dtype) | ||
tn = peps.H & peps | ||
Z_ex = tn.contract() | ||
Z_bp1 = contract_l1bp(tn, damping=damping) | ||
assert Z_bp1 == pytest.approx(Z_ex, rel=0.3) | ||
# compare with 2-norm BP on the peps directly | ||
Z_bp2 = contract_d2bp(peps) | ||
assert Z_bp1 == pytest.approx(Z_bp2, rel=5e-6) | ||
|
||
|
||
@pytest.mark.parametrize("dtype", ["float32", "complex64"]) | ||
def test_contract_tree_triple_sandwich_exact(dtype): | ||
edges = qtn.edges_tree_rand(20, 3, seed=42) | ||
ket = qtn.TN_from_edges_rand( | ||
edges, | ||
3, | ||
phys_dim=2, | ||
seed=42, | ||
site_ind_id="k{}", | ||
dtype=dtype, | ||
) | ||
op = qtn.TN_from_edges_rand( | ||
edges, | ||
2, | ||
phys_dim=2, | ||
seed=42, | ||
site_ind_id=("k{}", "b{}"), | ||
dtype=dtype, | ||
) | ||
bra = qtn.TN_from_edges_rand( | ||
edges, | ||
3, | ||
phys_dim=2, | ||
seed=42, | ||
site_ind_id="b{}", | ||
dtype=dtype, | ||
) | ||
tn = bra.H | op | ket | ||
Z_ex = tn.contract() | ||
Z_bp = contract_l1bp(tn) | ||
assert Z_ex == pytest.approx(Z_bp, rel=5e-6) | ||
|
||
@pytest.mark.parametrize("dtype", ["float32", "complex64"]) | ||
@pytest.mark.parametrize("damping", [0.0, 0.1]) | ||
def test_contract_tree_triple_sandwich_loopy_approx(dtype, damping): | ||
edges = qtn.edges_2d_hexagonal(2, 3) | ||
ket = qtn.TN_from_edges_rand( | ||
edges, | ||
3, | ||
phys_dim=2, | ||
seed=42, | ||
site_ind_id="k{}", | ||
dtype=dtype, | ||
# make the wavefunction postive to make easier | ||
dist='uniform', | ||
) | ||
ket /= (ket.H @ ket)**0.5 | ||
|
||
G_ket = ket.gate(qu.pauli('Z'), [(1, 1, 'A')], propagate_tags="sites") | ||
tn = ket.H | G_ket | ||
Z_ex = tn.contract() | ||
Z_bp = contract_l1bp(tn, damping=damping) | ||
assert Z_bp == pytest.approx(Z_ex, rel=0.5) |