Skip to content

Commit

Permalink
algebra: mps to_block2, complex
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Jan 17, 2024
1 parent 070df83 commit f1f1091
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 26 deletions.
43 changes: 37 additions & 6 deletions pyblock2/algebra/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,18 @@ def __mul__(self, other):
"""Scalar multiplication."""
return SubTensor(q_labels=self.q_labels, reduced=other * self.reduced)

def __truediv__(self, other):
"""Scalar division."""
return SubTensor(q_labels=self.q_labels, reduced=self.reduced / other)

def __neg__(self):
"""Times (-1)."""
return SubTensor(q_labels=self.q_labels, reduced=-self.reduced)

def conj(self):
"""Complex conjugate."""
return SubTensor(q_labels=self.q_labels, reduced=np.conj(self.reduced))

def equal_shape(self, other):
"""Test if two blocks have equal shape and quantum labels."""
return self.q_labels == other.q_labels and self.reduced.shape == other.reduced.shape
Expand Down Expand Up @@ -460,10 +468,18 @@ def __mul__(self, other):
"""Scalar multiplication."""
return Tensor(blocks=[block * other for block in self.blocks])

def __truediv__(self, other):
"""Scalar division."""
return Tensor(blocks=[block / other for block in self.blocks])

def __neg__(self):
"""Times (-1)."""
return Tensor(blocks=[-block for block in self.blocks])

def conj(self):
"""Complex conjugate."""
return Tensor(blocks=[block.conj() for block in self.blocks])

def __repr__(self):
return "\n".join("%3d %r" % (ib, b) for ib, b in enumerate(self.blocks))

Expand Down Expand Up @@ -545,13 +561,21 @@ def __mul__(self, other):
"""Scalar multiplication."""
return MPS(tensors=[self.tensors[0] * other] + self.tensors[1:])

def __truediv__(self, other):
"""Scalar division."""
return MPS(tensors=[self.tensors[0] / other] + self.tensors[1:])

def __rmul__(self, other):
return self * other

def __neg__(self):
"""Times (-1)."""
return MPS(tensors=[-self.tensors[0]] + self.tensors[1:])

def conj(self):
"""Complex conjugate."""
return MPS(tensors=[ts.conj() for ts in self.tensors])

def __add__(self, other):
"""Add two MPS. data in `other` MPS will be put in larger reduced indices."""
assert isinstance(other, MPS)
Expand Down Expand Up @@ -581,7 +605,7 @@ def __add__(self, other):
mshape[0] = lb[q[0]]
if i != self.n_sites - 1:
mshape[-1] = rb[q[-1]]
sub_mp[q] = SubTensor(q, np.zeros(tuple(mshape)))
sub_mp[q] = SubTensor(q, np.zeros(tuple(mshape), dtype=block.reduced.dtype))
# copy block self.blocks to smaller index in new block
for block in self.tensors[i].blocks:
q = block.q_labels
Expand Down Expand Up @@ -633,7 +657,7 @@ def __or__(self, other):
else:
lbra = Tensor.contract(left, self.tensors[i], [0], [0])
left = Tensor.contract(lbra, other.tensors[i], cidx, cidx)
assert isinstance(left, float)
assert isinstance(left, float) or isinstance(left, complex)

return left

Expand Down Expand Up @@ -669,7 +693,7 @@ def __matmul__(self, other):
else:
left = Tensor.contract(
lbra, other.tensors[i], [0, 1], [0, 1])
assert isinstance(left, float)
assert isinstance(left, float) or isinstance(left, complex)

return left

Expand Down Expand Up @@ -785,7 +809,7 @@ def merge_virtual_dims(self):
xsh = block.reduced.shape[:-2] + (nrr,)
if q not in map_blocks:
map_blocks[q] = SubTensor(
q_labels=q, reduced=np.zeros(sh))
q_labels=q, reduced=np.zeros(sh, dtype=block.reduced.dtype))
map_blocks[q].reduced[..., irr: irr
+ nrr] = block.reduced.reshape(xsh)
elif i == self.n_sites - 1:
Expand All @@ -795,7 +819,7 @@ def merge_virtual_dims(self):
xsh = (nll,) + block.reduced.shape[2:]
if q not in map_blocks:
map_blocks[q] = SubTensor(
q_labels=q, reduced=np.zeros(sh))
q_labels=q, reduced=np.zeros(sh, dtype=block.reduced.dtype))
map_blocks[q].reduced[ill:ill
+ nll, ...] = block.reduced.reshape(xsh)
else:
Expand All @@ -805,7 +829,7 @@ def merge_virtual_dims(self):
xsh = (nll,) + block.reduced.shape[2:-2] + (nrr,)
if q not in map_blocks:
map_blocks[q] = SubTensor(
q_labels=q, reduced=np.zeros(sh))
q_labels=q, reduced=np.zeros(sh, dtype=block.reduced.dtype))
map_blocks[q].reduced[ill:ill + nll, ...,
irr:irr + nrr] = block.reduced.reshape(xsh)
self.tensors[i] = Tensor(blocks=list(map_blocks.values()))
Expand Down Expand Up @@ -849,11 +873,18 @@ def __mul__(self, other):
"""Scalar multiplication."""
return MPO(tensors=[self.tensors[0] * other] + self.tensors[1:], const_e=other * self.const_e)

def __truediv__(self, other):
"""Scalar division."""
return MPO(tensors=[self.tensors[0] / other] + self.tensors[1:], const_e=self.const_e / other)

def __neg__(self):
"""Times (-1)."""
return MPO(tensors=[-self.tensors[0]] + self.tensors[1:], const_e=-self.const_e)

def conj(self):
"""Complex conjugate."""
return MPO(tensors=[ts.conj() for ts in self.tensors], const_e=np.conj(self.const_e))

def __add__(self, other):
"""Add two MPO. data in `other` MPO will be put in larger reduced indices."""
assert isinstance(other, MPO)
Expand Down
169 changes: 149 additions & 20 deletions pyblock2/algebra/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@

import numpy as np
from .core import MPS, MPO, Tensor, SubTensor
from block2 import OpTypes, QCTypes, SZ
from block2.sz import StateInfo, MPOQC


class TensorTools:
Expand Down Expand Up @@ -96,7 +94,6 @@ def from_block2_right_fused(bspmat, m, r, mr, cmr):
if bspmat.info.is_wavefunction:
qmr = -qmr
ik = mr.find_state(qmr)
kked = cmr.n if ik == mr.n - 1 else cmr.n_states[ik + 1]
pmat = np.array(bspmat[i])
nl = pmat.shape[0]
ip = 0
Expand Down Expand Up @@ -168,9 +165,9 @@ def from_block2(bmps):
bmps.info.load_left_dims(i)
l = bmps.info.left_dims[i]
m = bmps.info.basis[i]
lm = StateInfo.tensor_product_ref(
lm = m.__class__.tensor_product_ref(
l, m, bmps.info.left_dims_fci[i + 1])
clm = StateInfo.get_connection_info(l, m, lm)
clm = m.__class__.get_connection_info(l, m, lm)
bmps.load_tensor(i)
if i == bmps.n_sites - 1 and i == bmps.center and bmps.dot == 1:
if bmps.tensors[i].info.n == 1 and \
Expand All @@ -192,15 +189,15 @@ def from_block2(bmps):
bmps.info.load_right_dims(i + 1)
m = bmps.info.basis[i]
r = bmps.info.right_dims[i + 1]
mr = StateInfo.tensor_product_ref(
mr = m.__class__.tensor_product_ref(
m, r, bmps.info.right_dims_fci[i])
else:
bmps.info.load_right_dims(i + 2)
m = bmps.info.basis[i + 1]
r = bmps.info.right_dims[i + 2]
mr = StateInfo.tensor_product_ref(
mr = m.__class__.tensor_product_ref(
m, r, bmps.info.right_dims_fci[i + 1])
cmr = StateInfo.get_connection_info(m, r, mr)
cmr = m.__class__.get_connection_info(m, r, mr)
bmps.load_tensor(i)
tensors[i] = TensorTools.from_block2_right_fused(
bmps.tensors[i], m, r, mr, cmr)
Expand All @@ -214,12 +211,12 @@ def from_block2(bmps):
ma = bmps.info.basis[i]
mb = bmps.info.basis[i + 1]
r = bmps.info.right_dims[i + 2]
lm = StateInfo.tensor_product_ref(
lm = ma.__class__.tensor_product_ref(
l, ma, bmps.info.left_dims_fci[i + 1])
mr = StateInfo.tensor_product_ref(
mr = ma.__class__.tensor_product_ref(
mb, r, bmps.info.right_dims_fci[i + 1])
clm = StateInfo.get_connection_info(l, m, lm)
cmr = StateInfo.get_connection_info(m, r, mr)
clm = ma.__class__.get_connection_info(l, ma, lm)
cmr = ma.__class__.get_connection_info(mb, r, mr)
bmps.load_tensor(i)
tensors[i] = TensorTools.from_block2_left_and_right_fused(
bmps.tensors[i], l, ma, mb, r, lm, clm, mr, cmr)
Expand All @@ -246,13 +243,118 @@ def from_block2(bmps):
assert False
return MPS(tensors=tensors)

@staticmethod
def to_block2(mps, basis, center=0, tag='KET'):
"""
Translate pyblock2 MPS to block2 MPS.
Args:
mps : pyblock2 MPS
More than one physical index is not supported.
But fused index can be supported.
center : int
The pyblock2 MPS is transformed after
canonicalization at the given center site.
basis : List(Counter)
Phyiscal basis infomation at each site.
tag : str
Tag of the block2 MPS. Default is "KET".
Returns:
bmps : block2 MPS
To inspect this MPS, please make sure that the block2 global
scratch folder and stack memory are properly initialized.
"""
import block2 as b
Q = mps.tensors[0].blocks[0].q_labels[0].__class__
DT = mps.tensors[0].blocks[0].reduced.dtype
if Q == b.SZ and DT == np.complex128:
import block2.cpx.sz as bs, block2.sz as brs, block2.cpx as bx
elif Q == b.SZ and DT == np.float64:
import block2.sz as bs, block2.sz as brs, block2 as bx
elif Q == b.SU2 and DT == np.complex128:
import block2.cpx.su2 as bs, block2.su2 as brs, block2.cpx as bx
elif Q == b.SU2 and DT == np.float64:
import block2.su2 as bs, block2.su2 as brs, block2 as bx
elif Q == b.SGF and DT == np.complex128:
import block2.cpx.sgf as bs, block2.sgf as brs, block2.cpx as bx
elif Q == b.SGF and DT == np.float64:
import block2.sgf as bs, block2.sgf as brs, block2 as bx
elif Q == b.SGB and DT == np.complex128:
import block2.cpx.sgb as bs, block2.sgb as brs, block2.cpx as bx
elif Q == b.SGB and DT == np.float64:
import block2.sgb as bs, block2.sgb as brs, block2 as bx
else:
raise RuntimeError("Q = %s DT = %s not supported!" % (Q, DT))
if b.Global.frame is None:
raise RuntimeError("block2 is not initialized!")
save_dir = b.Global.frame.save_dir
mps.canonicalize(center)
n_sites = len(mps.tensors)
ql = mps.tensors[0].blocks[0].q_labels
qr = mps.tensors[-1].blocks[0].q_labels
vacuum = (ql[1] - ql[0])[0]
target = (qr[0] + qr[1])[0]
info = brs.MPSInfo(n_sites, vacuum, target, basis)
info.tag = tag
info.set_bond_dimension_full_fci(vacuum, vacuum)
info.left_dims[0] = brs.StateInfo(vacuum)
for i, xinfo in enumerate(mps.get_left_dims()):
p = info.left_dims[i + 1]
p.allocate(len(xinfo))
for ix, (k, v) in enumerate(xinfo.items()):
p.quanta[ix] = k
p.n_states[ix] = v
p.sort_states()
info.left_dims[n_sites] = brs.StateInfo(target)
info.right_dims[0] = brs.StateInfo(target)
for i, xinfo in enumerate(mps.get_right_dims()):
p = info.right_dims[i + 1]
p.allocate(len(xinfo))
for ix, (k, v) in enumerate(xinfo.items()):
p.quanta[ix] = (target - k)[0]
p.n_states[ix] = v
p.sort_states()
info.right_dims[n_sites] = brs.StateInfo(vacuum)
info.save_mutable()
info.save_data("%s/%s-mps_info.bin" % (save_dir, tag))
tensors = [bs.SparseTensor() for _ in range(n_sites)]
for i, bb in enumerate(basis):
tensors[i].data = bs.VectorVectorPSSTensor([bs.VectorPSSTensor() for _ in range(bb.n)])
for block in mps[i].blocks:
if i == 0:
ql = vacuum
qm, qr = block.q_labels
blk = block.reduced.reshape((1, *block.reduced.shape))
elif i == n_sites - 1:
ql, qm = block.q_labels
qr = target
blk = block.reduced.reshape((*block.reduced.shape, 1))
else:
ql, qm, qr = block.q_labels
blk = block.reduced
im = bb.find_state(qm)
assert im != -1
tensors[i].data[im].append(((ql, qr), bx.Tensor(b.VectorMKLInt(blk.shape))))
np.array(tensors[i].data[im][-1][1], copy=False)[:] = blk

umps = bs.UnfusedMPS()
umps.info = info
umps.n_sites = n_sites
umps.canonical_form = "L" * center + ("S" if center == n_sites - 1 else "K") + \
"R" * (n_sites - center - 1)
umps.center = center
umps.dot = 1
umps.tensors = bs.VectorSpTensor(tensors)
return umps.finalize()

class MPOTools:
@staticmethod
def from_block2(bmpo):
"""Translate block2 (un-simplified) MPO to pyblock2 MPO."""
assert bmpo.schemer is None
if isinstance(bmpo, MPOQC):
from block2 import OpTypes, QCTypes
if bmpo.__class__.__name__ == "MPOQC":
assert bmpo.mode == QCTypes.NC or bmpo.mode == QCTypes.CN
tensors = [None] * bmpo.n_sites
# tranlate operator name symbols to quantum labels
Expand Down Expand Up @@ -292,11 +394,12 @@ def from_block2(bmpo):
nu = spmat.info.n_states_bra[p]
nd = spmat.info.n_states_ket[p]
qx = (qu, qd, qr)
spm = np.array(spmat[p])
if qx not in map_blocks:
map_blocks[qx] = SubTensor(
q_labels=qx, reduced=np.zeros((nu, nd, nr)))
q_labels=qx, reduced=np.zeros((nu, nd, nr), dtype=spm.dtype))
map_blocks[qx].reduced[:, :, ir] += expr.factor * \
spmat.factor * np.array(spmat[p])
spmat.factor * spm
else:
assert False
elif i == bmpo.n_sites - 1:
Expand All @@ -317,11 +420,12 @@ def from_block2(bmpo):
nu = spmat.info.n_states_bra[p]
nd = spmat.info.n_states_ket[p]
qx = (ql, qu, qd)
spm = np.array(spmat[p])
if qx not in map_blocks:
map_blocks[qx] = SubTensor(
q_labels=qx, reduced=np.zeros((nl, nu, nd)))
q_labels=qx, reduced=np.zeros((nl, nu, nd), dtype=spm.dtype))
map_blocks[qx].reduced[il, :, :] += expr.factor * \
spmat.factor * np.array(spmat[p])
spmat.factor * spm
else:
assert False
else:
Expand All @@ -343,13 +447,38 @@ def from_block2(bmpo):
nu = spmat.info.n_states_bra[p]
nd = spmat.info.n_states_ket[p]
qx = (ql, qu, qd, qr)
if np.linalg.norm(np.array(spmat[p])) == 0:
spm = np.array(spmat[p])
if np.linalg.norm(spm) == 0:
continue
if qx not in map_blocks:
map_blocks[qx] = SubTensor(
q_labels=qx, reduced=np.zeros((nl, nu, nd, nr)))
q_labels=qx, reduced=np.zeros((nl, nu, nd, nr), dtype=spm.dtype))
map_blocks[qx].reduced[il, :, :,
ir] += expr.factor * spmat.factor * np.array(spmat[p])
ir] += expr.factor * spmat.factor * spm
elif expr.get_type() == OpTypes.Sum:
for xexpr in expr.strings:
spmat = ops[xexpr.a.abs()]
if spmat.factor == 0 or spmat.info.n == 0:
continue
ql, qr = idx_qss[i - 1][j], idx_qss[i][k]
nl, nr = len(idx_mps[i - 1][ql]
), len(idx_mps[i][qr])
il, ir = idx_imps[i - 1][j], idx_imps[i][k]
for p in range(spmat.info.n):
qu = spmat.info.quanta[p].get_bra(
spmat.info.delta_quantum)
qd = spmat.info.quanta[p].get_ket()
nu = spmat.info.n_states_bra[p]
nd = spmat.info.n_states_ket[p]
qx = (ql, qu, qd, qr)
spm = np.array(spmat[p])
if np.linalg.norm(spm) == 0:
continue
if qx not in map_blocks:
map_blocks[qx] = SubTensor(
q_labels=qx, reduced=np.zeros((nl, nu, nd, nr), dtype=spm.dtype))
map_blocks[qx].reduced[il, :, :,
ir] += xexpr.factor * spmat.factor * spm
else:
assert False
tensors[i] = Tensor(blocks=list(map_blocks.values()))
Expand Down

0 comments on commit f1f1091

Please sign in to comment.