Skip to content

Commit 9df4fed

Browse files
committed
TNOptimizer: warn and cast for precision loss (jax)
1 parent 2e13857 commit 9df4fed

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

quimb/tensor/optimize.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,18 @@ def pack(self, tree, name="vector"):
129129
for array, info in zip(arrays, self.infos):
130130
if not isinstance(array, np.ndarray):
131131
array = to_numpy(array)
132+
133+
if array.dtype != info.dtype:
134+
warnings.warn(
135+
"dtype mismatch between input parameter and updated "
136+
"values. This can occur e.g. with jax and double "
137+
"precision arrays (in which case consider setting "
138+
'`jax.config.update("jax_enable_x64", True)` at startup '
139+
"or using single precision parameters directly)."
140+
f" For now casting from {array.dtype} to {info.dtype}."
141+
)
142+
array = array.astype(info.dtype)
143+
132144
# flatten
133145
if info.iscomplex:
134146
# view as real array of double the length

0 commit comments

Comments
 (0)