@@ -903,6 +903,13 @@ def tensor_multifuse(ts, inds, gauges=None):
903
903
# contract into a single gauge
904
904
gauges [inds [0 ]] = functools .reduce (lambda x , y : do ("kron" , x , y ), gs )
905
905
906
+ if hasattr (ts [0 ].data , "align_axes" ):
907
+ arrays = [t .data for t in ts ]
908
+ axes = [tuple (map (t .inds .index , inds )) for t in ts ]
909
+ arrays = do ("align_axes" , * arrays , axes )
910
+ for t , a in zip (ts , arrays ):
911
+ t .modify (data = a )
912
+
906
913
# index fusing
907
914
for t in ts :
908
915
t .fuse_ ({inds [0 ]: inds })
@@ -4178,6 +4185,19 @@ def conj(self, mangle_inner=False, inplace=False):
4178
4185
append = None if mangle_inner is True else str (mangle_inner )
4179
4186
tn .mangle_inner_ (append )
4180
4187
4188
+ if hasattr (next (iter (tn .tensor_map .values ())), "phase_flip" ):
4189
+ # need to phase dual outer indices
4190
+ outer_inds = tn .outer_inds ()
4191
+ for t in tn :
4192
+ data = t .data
4193
+ dual_outer_axs = tuple (
4194
+ ax
4195
+ for ax , ix in enumerate (t .inds )
4196
+ if (ix in outer_inds ) and not data .indices [ax ].dual
4197
+ )
4198
+ if dual_outer_axs :
4199
+ t .modify (data = data .phase_flip (* dual_outer_axs ))
4200
+
4181
4201
return tn
4182
4202
4183
4203
conj_ = functools .partialmethod (conj , inplace = True )
0 commit comments