Skip to content

Commit 737c22d

Browse files
committed
Replace safe_zip by built-in zip function with strict=True
1 parent 7445798 commit 737c22d

File tree

7 files changed

+29
-28
lines changed

7 files changed

+29
-28
lines changed

varipeps/ctmrg/routine.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -771,8 +771,10 @@ def _ctmrg_rev_while_body(carry):
771771
bar_fixed_point = bar_fixed_point_last_step.replace_unique_tensors(
772772
[
773773
t_old.__add__(t_new, checks=False)
774-
for t_old, t_new in jax.util.safe_zip(
775-
initial_bar.get_unique_tensors(), new_env_bar.get_unique_tensors()
774+
for t_old, t_new in zip(
775+
initial_bar.get_unique_tensors(),
776+
new_env_bar.get_unique_tensors(),
777+
strict=True,
776778
)
777779
]
778780
)

varipeps/mapping/florett_pentagon.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ def _calc_onsite_gate(
579579
Id_other_sites = jnp.eye(d**7)
580580

581581
for i, (b_e, g_e, blue_e) in enumerate(
582-
jax.util.safe_zip(black_gates, green_gates, blue_gates)
582+
zip(black_gates, green_gates, blue_gates, strict=True)
583583
):
584584
black_12 = jnp.kron(b_e, Id_other_sites)
585585

@@ -673,7 +673,7 @@ def _calc_right_gate(
673673

674674
Id_other_site = jnp.eye(d)
675675

676-
for i, (b_e, blue_e) in enumerate(jax.util.safe_zip(black_gates, blue_gates)):
676+
for i, (b_e, blue_e) in enumerate(zip(black_gates, blue_gates, strict=True)):
677677
black_51 = jnp.kron(b_e, Id_other_site)
678678

679679
blue_59 = jnp.kron(blue_e, Id_other_site)
@@ -701,7 +701,7 @@ def _calc_down_gate(
701701

702702
Id_other_site = jnp.eye(d**2)
703703

704-
for i, (b_e, blue_e) in enumerate(jax.util.safe_zip(black_gates, blue_gates)):
704+
for i, (b_e, blue_e) in enumerate(zip(black_gates, blue_gates, strict=True)):
705705
black_91 = jnp.kron(b_e, Id_other_site)
706706
black_91 = black_91.reshape(d, d, d, d, d, d, d, d)
707707
black_91 = black_91.transpose(2, 0, 1, 3, 6, 4, 5, 7)
@@ -1157,11 +1157,12 @@ def __call__(
11571157
)
11581158

11591159
for sr_i, (sr_o, sr_h, sr_v, sr_d) in enumerate(
1160-
jax.util.safe_zip(
1160+
zip(
11611161
step_result_onsite[: len(self.black_gates)],
11621162
step_result_horizontal[: len(self.black_gates)],
11631163
step_result_vertical[: len(self.black_gates)],
11641164
step_result_diagonal[: len(self.black_gates)],
1165+
strict=True,
11651166
)
11661167
):
11671168
result[sr_i] += sr_o + sr_h + sr_v + sr_d

varipeps/mapping/maple_leaf.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def _calc_onsite_gate(
262262
single_gates = [None] * result_length
263263

264264
for i, (g_e, b_e, r_e) in enumerate(
265-
jax.util.safe_zip(green_gates, blue_gates, red_gates)
265+
zip(green_gates, blue_gates, red_gates, strict=True)
266266
):
267267
(
268268
green_12,
@@ -327,7 +327,7 @@ def _calc_right_gate(
327327

328328
single_gates = [None] * result_length
329329

330-
for i, (b_e, r_e) in enumerate(jax.util.safe_zip(blue_gates, red_gates)):
330+
for i, (b_e, r_e) in enumerate(zip(blue_gates, red_gates, strict=True)):
331331
red_61, blue_62 = get_right_gates(b_e, r_e, d)
332332

333333
result[i] = red_61 + blue_62
@@ -361,7 +361,7 @@ def _calc_down_gate(
361361

362362
single_gates = [None] * result_length
363363

364-
for i, (b_e, r_e) in enumerate(jax.util.safe_zip(blue_gates, red_gates)):
364+
for i, (b_e, r_e) in enumerate(zip(blue_gates, red_gates, strict=True)):
365365
blue_35, red_36 = get_down_gates(b_e, r_e, d)
366366

367367
result[i] = blue_35 + red_36
@@ -395,7 +395,7 @@ def _calc_diagonal_gate(
395395

396396
single_gates = [None] * result_length
397397

398-
for i, (b_e, r_e) in enumerate(jax.util.safe_zip(blue_gates, red_gates)):
398+
for i, (b_e, r_e) in enumerate(zip(blue_gates, red_gates, strict=True)):
399399
blue_41, red_31 = get_diagonal_gates(b_e, r_e, d)
400400

401401
result[i] = blue_41 + red_31
@@ -802,11 +802,12 @@ def __call__(
802802
)
803803

804804
for sr_i, (sr_o, sr_h, sr_v, sr_d) in enumerate(
805-
jax.util.safe_zip(
805+
zip(
806806
step_result_onsite[: len(self.green_gates)],
807807
step_result_horizontal[: len(self.green_gates)],
808808
step_result_vertical[: len(self.green_gates)],
809809
step_result_diagonal[: len(self.green_gates)],
810+
strict=True,
810811
)
811812
):
812813
result[sr_i] += sr_o + sr_h + sr_v + sr_d

varipeps/mapping/square_kagome.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ def __call__(
769769
)
770770

771771
for sr_i, (sr_h, sr_v) in enumerate(
772-
jax.util.safe_zip(step_result_horizontal, step_result_vertical)
772+
zip(step_result_horizontal, step_result_vertical, strict=True)
773773
):
774774
result[sr_i] += sr_h + sr_v
775775
elif len(self.cross_gates) > 0:

varipeps/mapping/triangular.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,7 @@ def __call__(
232232
)
233233

234234
for sr_i, (sr_x, sr_y, sr_diagonal) in enumerate(
235-
jax.util.safe_zip(
236-
step_result_x, step_result_y, step_result_diagonal
237-
)
235+
zip(step_result_x, step_result_y, step_result_diagonal, strict=True)
238236
):
239237
result[sr_i] += sr_x + sr_y + sr_diagonal
240238

@@ -633,7 +631,7 @@ def _calc_quadrat_gate_next_nearest(
633631
Id_other_sites = jnp.eye(d**2)
634632

635633
for i, (n_e, n_n_e) in enumerate(
636-
jax.util.safe_zip(nearest_gates, next_nearest_gates)
634+
zip(nearest_gates, next_nearest_gates, strict=True)
637635
):
638636
nearest_34 = jnp.kron(Id_other_sites, n_e)
639637

@@ -913,10 +911,11 @@ def __call__(
913911
)
914912

915913
for sr_i, (sr_q, sr_h, sr_v) in enumerate(
916-
jax.util.safe_zip(
914+
zip(
917915
step_result_quadrat[: len(self.nearest_neighbor_gates)],
918916
step_result_horizontal_rect[: len(self.nearest_neighbor_gates)],
919917
step_result_vertical_rect[: len(self.nearest_neighbor_gates)],
918+
strict=True,
920919
)
921920
):
922921
result[sr_i] += sr_q + sr_h + sr_v

varipeps/optimization/inner_function.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import jax.numpy as jnp
22
from jax import value_and_grad
3-
from jax.util import safe_zip
43

54
from varipeps import varipeps_config
65
from varipeps.peps import PEPS_Unit_Cell
@@ -39,7 +38,7 @@ def _map_tensors(
3938
old_tensors = unitcell.get_unique_tensors()
4039
if not all(
4140
jnp.allclose(ti, tj_obj.tensor)
42-
for ti, tj_obj in safe_zip(peps_tensors, old_tensors)
41+
for ti, tj_obj in zip(peps_tensors, old_tensors, strict=True)
4342
):
4443
raise ValueError(
4544
"Input tensors and provided unitcell are not the same state."
@@ -110,14 +109,14 @@ def calc_ctmrg_expectation(
110109
spiral_vectors_x = (spiral_vectors_x,)
111110
spiral_vectors = tuple(
112111
jnp.array((sx, sy))
113-
for sx, sy in safe_zip(spiral_vectors_x, spiral_vectors)
112+
for sx, sy in zip(spiral_vectors_x, spiral_vectors, strict=True)
114113
)
115114
elif spiral_vectors_y is not None:
116115
if isinstance(spiral_vectors_y, jnp.ndarray):
117116
spiral_vectors_y = (spiral_vectors_y,)
118117
spiral_vectors = tuple(
119118
jnp.array((sx, sy))
120-
for sx, sy in safe_zip(spiral_vectors, spiral_vectors_y)
119+
for sx, sy in zip(spiral_vectors, spiral_vectors_y, strict=True)
121120
)
122121
else:
123122
peps_tensors, unitcell = _map_tensors(
@@ -211,14 +210,14 @@ def calc_preconverged_ctmrg_value_and_grad(
211210
spiral_vectors_x = (spiral_vectors_x,)
212211
spiral_vectors = tuple(
213212
jnp.array((sx, sy))
214-
for sx, sy in safe_zip(spiral_vectors_x, spiral_vectors)
213+
for sx, sy in zip(spiral_vectors_x, spiral_vectors, strict=True)
215214
)
216215
elif spiral_vectors_y is not None:
217216
if isinstance(spiral_vectors_y, jnp.ndarray):
218217
spiral_vectors_y = (spiral_vectors_y,)
219218
spiral_vectors = tuple(
220219
jnp.array((sx, sy))
221-
for sx, sy in safe_zip(spiral_vectors, spiral_vectors_y)
220+
for sx, sy in zip(spiral_vectors, spiral_vectors_y, strict=True)
222221
)
223222
else:
224223
peps_tensors, unitcell = _map_tensors(
@@ -293,14 +292,14 @@ def calc_ctmrg_expectation_custom(
293292
spiral_vectors_x = (spiral_vectors_x,)
294293
spiral_vectors = tuple(
295294
jnp.array((sx, sy))
296-
for sx, sy in safe_zip(spiral_vectors_x, spiral_vectors)
295+
for sx, sy in zip(spiral_vectors_x, spiral_vectors, strict=True)
297296
)
298297
elif spiral_vectors_y is not None:
299298
if isinstance(spiral_vectors_y, jnp.ndarray):
300299
spiral_vectors_y = (spiral_vectors_y,)
301300
spiral_vectors = tuple(
302301
jnp.array((sx, sy))
303-
for sx, sy in safe_zip(spiral_vectors, spiral_vectors_y)
302+
for sx, sy in zip(spiral_vectors, spiral_vectors_y, strict=True)
304303
)
305304
else:
306305
peps_tensors, unitcell = _map_tensors(

varipeps/optimization/optimizer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import jax.numpy as jnp
1414
from jax.lax import scan
1515
from jax.flatten_util import ravel_pytree
16-
from jax.util import safe_zip
1716

1817
from varipeps import varipeps_config, varipeps_global_state
1918
from varipeps.config import Optimizing_Methods
@@ -240,14 +239,14 @@ def _autosave_wrapper(
240239
spiral_vectors_x = (spiral_vectors_x,)
241240
spiral_vectors = tuple(
242241
jnp.array((sx, sy))
243-
for sx, sy in safe_zip(spiral_vectors_x, spiral_vectors)
242+
for sx, sy in zip(spiral_vectors_x, spiral_vectors, strict=True)
244243
)
245244
elif spiral_vectors_y is not None:
246245
if isinstance(spiral_vectors_y, jnp.ndarray):
247246
spiral_vectors_y = (spiral_vectors_y,)
248247
spiral_vectors = tuple(
249248
jnp.array((sx, sy))
250-
for sx, sy in safe_zip(spiral_vectors, spiral_vectors_y)
249+
for sx, sy in zip(spiral_vectors, spiral_vectors_y, strict=True)
251250
)
252251
elif additional_input.get("spiral_vectors") is not None:
253252
spiral_vectors = additional_input.get("spiral_vectors")

0 commit comments

Comments
 (0)