Skip to content

Commit f4ff318

Browse files
committed
add some symmray support
1 parent 7cd4f4a commit f4ff318

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

quimb/tensor/tensor_core.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,13 @@ def tensor_multifuse(ts, inds, gauges=None):
903903
# contract into a single gauge
904904
gauges[inds[0]] = functools.reduce(lambda x, y: do("kron", x, y), gs)
905905

906+
if hasattr(ts[0].data, "align_axes"):
907+
arrays = [t.data for t in ts]
908+
axes = [tuple(map(t.inds.index, inds)) for t in ts]
909+
arrays = do("align_axes", *arrays, axes)
910+
for t, a in zip(ts, arrays):
911+
t.modify(data=a)
912+
906913
# index fusing
907914
for t in ts:
908915
t.fuse_({inds[0]: inds})
@@ -4178,6 +4185,19 @@ def conj(self, mangle_inner=False, inplace=False):
41784185
append = None if mangle_inner is True else str(mangle_inner)
41794186
tn.mangle_inner_(append)
41804187

4188+
if hasattr(next(iter(tn.tensor_map.values())), "phase_flip"):
4189+
# need to phase dual outer indices
4190+
outer_inds = tn.outer_inds()
4191+
for t in tn:
4192+
data = t.data
4193+
dual_outer_axs = tuple(
4194+
ax
4195+
for ax, ix in enumerate(t.inds)
4196+
if (ix in outer_inds) and not data.indices[ax].dual
4197+
)
4198+
if dual_outer_axs:
4199+
t.modify(data=data.phase_flip(*dual_outer_axs))
4200+
41814201
return tn
41824202

41834203
conj_ = functools.partialmethod(conj, inplace=True)

0 commit comments

Comments
 (0)