Skip to content

Commit 028e75f

Browse files
committed
L2BP: add update="sequential"
1 parent c8ea241 commit 028e75f

File tree

4 files changed

+47
-13
lines changed

4 files changed

+47
-13
lines changed

quimb/experimental/belief_propagation/l1bp.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,16 @@ def _update_m(key, data):
170170

171171
if self.update == "parallel":
172172
new_data = {}
173+
# compute all new messages
173174
while self.touched:
174175
key = self.touched.pop()
175176
new_data[key] = _compute_m(key)
177+
# insert all new messages
176178
for key, data in new_data.items():
177179
_update_m(key, data)
178180

179181
elif self.update == "sequential":
182+
# compute each new message and immediately re-insert it
180183
while self.touched:
181184
key = self.touched.pop()
182185
data = _compute_m(key)

quimb/experimental/belief_propagation/l2bp.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ def __init__(
1919
site_tags=None,
2020
damping=0.0,
2121
local_convergence=True,
22+
update="parallel",
2223
optimize="auto-hq",
2324
**contract_opts,
2425
):
2526
self.backend = next(t.backend for t in tn)
2627
self.damping = damping
2728
self.local_convergence = local_convergence
29+
self.update = update
2830
self.optimize = optimize
2931
self.contract_opts = contract_opts
3032

@@ -126,10 +128,12 @@ def iterate(self, tol=5e-6):
126128
)
127129

128130
ncheck = len(self.touched)
131+
nconv = 0
132+
max_mdiff = -1.0
133+
new_touched = set()
129134

130-
new_data = {}
131-
while self.touched:
132-
i, j = self.touched.pop()
135+
def _compute_m(key):
136+
i, j = key
133137
bix = self.edges[(i, j) if i < j else (j, i)]
134138
cix = tuple(ix + "**" for ix in bix)
135139
output_inds = cix + bix
@@ -145,12 +149,11 @@ def iterate(self, tol=5e-6):
145149
)
146150
tm_new.modify(apply=self._symmetrize)
147151
tm_new.modify(apply=self._normalize)
148-
# defer setting the data to do a parallel update
149-
new_data[i, j] = tm_new.data
152+
return tm_new.data
153+
154+
def _update_m(key, data):
155+
nonlocal nconv, max_mdiff
150156

151-
nconv = 0
152-
max_mdiff = -1.0
153-
for key, data in new_data.items():
154157
tm = self.messages[key]
155158

156159
if self.damping > 0.0:
@@ -160,13 +163,32 @@ def iterate(self, tol=5e-6):
160163

161164
if mdiff > tol:
162165
# mark touching messages for update
163-
self.touched.update(self.touch_map[key])
166+
new_touched.update(self.touch_map[key])
164167
else:
165168
nconv += 1
166169

167170
max_mdiff = max(max_mdiff, mdiff)
168171
tm.modify(data=data)
169172

173+
if self.update == "parallel":
174+
new_data = {}
175+
# compute all new messages
176+
while self.touched:
177+
key = self.touched.pop()
178+
new_data[key] = _compute_m(key)
179+
# insert all new messages
180+
for key, data in new_data.items():
181+
_update_m(key, data)
182+
183+
elif self.update == "sequential":
184+
# compute each new message and immediately re-insert it
185+
while self.touched:
186+
key = self.touched.pop()
187+
data = _compute_m(key)
188+
_update_m(key, data)
189+
190+
self.touched = new_touched
191+
170192
return nconv, ncheck, max_mdiff
171193

172194
def contract(self, strip_exponent=False):

tests/test_tensor/test_belief_propagation/test_l1bp.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,15 @@ def test_contract_loopy_approx(dtype, damping):
2929

3030
@pytest.mark.parametrize("dtype", ["float32", "complex64"])
3131
@pytest.mark.parametrize("damping", [0.0, 0.1])
32-
def test_contract_double_loopy_approx(dtype, damping):
32+
@pytest.mark.parametrize("update", ("parallel", "sequential"))
33+
def test_contract_double_loopy_approx(dtype, damping, update):
3334
peps = qtn.PEPS.rand(4, 3, 2, seed=42, dtype=dtype)
3435
tn = peps.H & peps
3536
Z_ex = tn.contract()
3637
info = {}
37-
Z_bp1 = contract_l1bp(tn, damping=damping, info=info, progbar=True)
38+
Z_bp1 = contract_l1bp(
39+
tn, damping=damping, update=update, info=info, progbar=True
40+
)
3841
assert info["converged"]
3942
assert Z_bp1 == pytest.approx(Z_ex, rel=0.3)
4043
# compare with 2-norm BP on the peps directly

tests/test_tensor/test_belief_propagation/test_l2bp.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def test_contract_double_layer_tree_exact(dtype):
6969

7070
@pytest.mark.parametrize("dtype", ["float32", "complex64"])
7171
@pytest.mark.parametrize("damping", [0.0, 0.1])
72-
def test_compress_double_layer_loopy(dtype, damping):
72+
@pytest.mark.parametrize("update", ["parallel", "sequential"])
73+
def test_compress_double_layer_loopy(dtype, damping, update):
7374
peps = qtn.PEPS.rand(3, 4, bond_dim=3, seed=42, dtype=dtype)
7475
pepo = qtn.PEPO.rand(3, 4, bond_dim=2, seed=42, dtype=dtype)
7576

@@ -85,7 +86,12 @@ def test_compress_double_layer_loopy(dtype, damping):
8586
# compress using BP
8687
info = {}
8788
tn_bp = compress_l2bp(
88-
tn_lazy, max_bond=3, damping=damping, info=info, progbar=True
89+
tn_lazy,
90+
max_bond=3,
91+
damping=damping,
92+
update=update,
93+
info=info,
94+
progbar=True,
8995
)
9096
assert info["converged"]
9197
assert tn_bp.num_tensors == 12

0 commit comments

Comments
 (0)