|
1 | 1 | import jax.numpy as jnp
|
2 | 2 | from jax import value_and_grad
|
3 |
| -from jax.util import safe_zip |
4 | 3 |
|
5 | 4 | from varipeps import varipeps_config
|
6 | 5 | from varipeps.peps import PEPS_Unit_Cell
|
@@ -39,7 +38,7 @@ def _map_tensors(
|
39 | 38 | old_tensors = unitcell.get_unique_tensors()
|
40 | 39 | if not all(
|
41 | 40 | jnp.allclose(ti, tj_obj.tensor)
|
42 |
| - for ti, tj_obj in safe_zip(peps_tensors, old_tensors) |
| 41 | + for ti, tj_obj in zip(peps_tensors, old_tensors, strict=True) |
43 | 42 | ):
|
44 | 43 | raise ValueError(
|
45 | 44 | "Input tensors and provided unitcell are not the same state."
|
@@ -110,14 +109,14 @@ def calc_ctmrg_expectation(
|
110 | 109 | spiral_vectors_x = (spiral_vectors_x,)
|
111 | 110 | spiral_vectors = tuple(
|
112 | 111 | jnp.array((sx, sy))
|
113 |
| - for sx, sy in safe_zip(spiral_vectors_x, spiral_vectors) |
| 112 | + for sx, sy in zip(spiral_vectors_x, spiral_vectors, strict=True) |
114 | 113 | )
|
115 | 114 | elif spiral_vectors_y is not None:
|
116 | 115 | if isinstance(spiral_vectors_y, jnp.ndarray):
|
117 | 116 | spiral_vectors_y = (spiral_vectors_y,)
|
118 | 117 | spiral_vectors = tuple(
|
119 | 118 | jnp.array((sx, sy))
|
120 |
| - for sx, sy in safe_zip(spiral_vectors, spiral_vectors_y) |
| 119 | + for sx, sy in zip(spiral_vectors, spiral_vectors_y, strict=True) |
121 | 120 | )
|
122 | 121 | else:
|
123 | 122 | peps_tensors, unitcell = _map_tensors(
|
@@ -211,14 +210,14 @@ def calc_preconverged_ctmrg_value_and_grad(
|
211 | 210 | spiral_vectors_x = (spiral_vectors_x,)
|
212 | 211 | spiral_vectors = tuple(
|
213 | 212 | jnp.array((sx, sy))
|
214 |
| - for sx, sy in safe_zip(spiral_vectors_x, spiral_vectors) |
| 213 | + for sx, sy in zip(spiral_vectors_x, spiral_vectors, strict=True) |
215 | 214 | )
|
216 | 215 | elif spiral_vectors_y is not None:
|
217 | 216 | if isinstance(spiral_vectors_y, jnp.ndarray):
|
218 | 217 | spiral_vectors_y = (spiral_vectors_y,)
|
219 | 218 | spiral_vectors = tuple(
|
220 | 219 | jnp.array((sx, sy))
|
221 |
| - for sx, sy in safe_zip(spiral_vectors, spiral_vectors_y) |
| 220 | + for sx, sy in zip(spiral_vectors, spiral_vectors_y, strict=True) |
222 | 221 | )
|
223 | 222 | else:
|
224 | 223 | peps_tensors, unitcell = _map_tensors(
|
@@ -293,14 +292,14 @@ def calc_ctmrg_expectation_custom(
|
293 | 292 | spiral_vectors_x = (spiral_vectors_x,)
|
294 | 293 | spiral_vectors = tuple(
|
295 | 294 | jnp.array((sx, sy))
|
296 |
| - for sx, sy in safe_zip(spiral_vectors_x, spiral_vectors) |
| 295 | + for sx, sy in zip(spiral_vectors_x, spiral_vectors, strict=True) |
297 | 296 | )
|
298 | 297 | elif spiral_vectors_y is not None:
|
299 | 298 | if isinstance(spiral_vectors_y, jnp.ndarray):
|
300 | 299 | spiral_vectors_y = (spiral_vectors_y,)
|
301 | 300 | spiral_vectors = tuple(
|
302 | 301 | jnp.array((sx, sy))
|
303 |
| - for sx, sy in safe_zip(spiral_vectors, spiral_vectors_y) |
| 302 | + for sx, sy in zip(spiral_vectors, spiral_vectors_y, strict=True) |
304 | 303 | )
|
305 | 304 | else:
|
306 | 305 | peps_tensors, unitcell = _map_tensors(
|
|
0 commit comments