Skip to content

Commit

Permalink
mpo trans symm
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Jun 27, 2024
1 parent 3e561e3 commit e9578ca
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 1 deletion.
30 changes: 30 additions & 0 deletions pyblock2/driver/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6798,6 +6798,36 @@ def mps_flip_twos(self, mps):
zmps = umps.finalize(self.prule)
return zmps

def mpo_change_symm(self, mpo, tag="", add_ident=True):
"""
Change symmetry type of MPO.
Only works in SAny mode. The resulting MPO should be used in SAny mode.
Args:
mpo : MPO
The input MPO.
tag : str
The tag of the output MPO.
add_ident : bool
If True, the hidden identity operator will be added into the MPO. Default is True.
Returns:
rmpo : MPO
The output MPO.
"""
bw = self.bw
assert SymmetryTypes.SAny in bw.symm_type
if self.mpi:
rmpo = bw.bs.trans_mpo_to_sany(mpo.prim_mpo.prim_mpo, self.ghamil, tag)
else:
rmpo = bw.bs.trans_mpo_to_sany(mpo.prim_mpo, self.ghamil, tag)
rmpo = bw.bs.SimplifiedMPO(rmpo, bw.bs.Rule(), False, False)
if add_ident:
rmpo = bw.bs.IdentityAddedMPO(rmpo)
if self.mpi:
rmpo = bw.bs.ParallelMPO(rmpo, self.prule)
return rmpo

def mps_change_symm(self, mps, tag, target):
"""
Change symmetry type of MPS.
Expand Down
176 changes: 176 additions & 0 deletions src/dmrg/mpo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1616,4 +1616,180 @@ template <typename S, typename FL> struct IdentityAddedMPO : MPO<S, FL> {
}
};

template <typename S1, typename S2, typename FL> struct TransMPO {
typedef typename GMatrix<FL>::FP FP;
static shared_ptr<MPO<S2, FL>>
forward(const shared_ptr<MPO<S1, FL>> &mpo,
const shared_ptr<Hamiltonian<S2, FL>> &hamil,
const string &tag = "") {
shared_ptr<MPO<S1, FL>> rmpo = make_shared<MPO<S1, FL>>(
mpo->n_sites, tag == "" ? "TR-" + mpo->tag : tag);
const int n_sites = mpo->n_sites;
rmpo->const_e = mpo->const_e;
const S2 ref = hamil->vacuum;
rmpo->tf = make_shared<TensorFunctions<S2, FL>>(hamil->opf);
rmpo->site_op_infos = hamil->site_op_infos;
rmpo->basis = hamil->basis;
rmpo->hamil = hamil;
rmpo->left_operator_names.resize(n_sites, nullptr);
rmpo->right_operator_names.resize(n_sites, nullptr);
rmpo->tensors.resize(n_sites, nullptr);
for (uint16_t m = 0; m < n_sites; m++)
rmpo->tensors[m] = make_shared<OperatorTensor<S2, FL>>();
rmpo->sparse_form = mpo->sparse_form;
rmpo->schemer = nullptr;
assert(mpo->schemer == nullptr);
shared_ptr<OpExpr<S2>> zero = make_shared<OpExpr<S2>>();
auto tr = [&ref](S1 q) -> S2 {
return TransStateInfo<S1, S2>::forward(
make_shared<StateInfo<S1>>(q), ref)
->quanta[0];
};
rmpo->op = make_shared<OpElement<S2, FL>>(
mpo->op->name, mpo->op->site_index, tr(mpo->op->q_label),
mpo->op->factor);
rmpo->left_vacuum = tr(mpo->left_vacuum);
for (int ii = 0; ii < n_sites; ii++) {
mpo->load_tensor(ii);
mpo->load_left_operators(ii);
mpo->load_right_operators(ii);
shared_ptr<OperatorTensor<S2, FL>> opt = rmpo->tensors[ii];
shared_ptr<OperatorTensor<S2, FL>> mopt = mpo->tensors[ii];
shared_ptr<Symbolic<S2>> pmat;
assert(mopt->lmat == mopt->rmat);
if (ii == 0)
pmat = make_shared<SymbolicRowVector<S2>>(mopt->lmat->n);
else if (ii == n_sites - 1)
pmat = make_shared<SymbolicColumnVector<S2>>(mopt->lmat->m);
else {
pmat = make_shared<SymbolicMatrix<S2>>(mopt->lmat->m,
mopt->lmat->n);
dynamic_pointer_cast<SymbolicMatrix<S2>>(pmat)->indices =
dynamic_pointer_cast<SymbolicMatrix<S1>>(mopt->lmat)
->indices;
pmat->data.resize(mopt->lmat->data.size());
}
opt->lmat = opt->rmat = pmat;
for (int iop = 0; iop < mopt->lmat->data.size(); iop++)
if (mopt->lmat->data[iop]->get_type() == OpTypes::Zero)
pmat->data[iop] = zero;
else if (mopt->lmat->data[iop]->get_type() == OpTypes::Elem) {
const auto p = dynamic_pointer_cast<OpElement<S1, FL>>(
mopt->lmat->data[iop]);
pmat->data[iop] = make_shared<OpElement<S2, FL>>(
p->name, p->site_index, tr(p->q_label), p->factor);
} else if (mopt->lmat->data[iop]->get_type() == OpTypes::Sum) {
const auto p = dynamic_pointer_cast<OpSum<S1, FL>>(
mopt->lmat->data[iop]);
vector<shared_ptr<OpExpr<S2>>> strings(p->strings.size());
for (size_t j = 0; j < p->strings.size(); j++) {
const auto pp = dynamic_pointer_cast<OpElement<S1, FL>>(
p->strings[j]);
strings[j] = make_shared<OpElement<S2, FL>>(
pp->name, pp->site_index, tr(pp->q_label),
pp->factor);
}
pmat->data[iop] = sum(strings);
} else
assert(false);
auto lop = make_shared<SymbolicRowVector<S2>>(
mpo->left_operator_names[ii]->n);
auto rop = make_shared<SymbolicColumnVector<S2>>(
mpo->right_operator_names[ii]->m);
for (int iop = 0; iop < lop->data.size(); iop++) {
const auto p = dynamic_pointer_cast<OpElement<S1, FL>>(
mpo->left_operator_names[ii]->data[iop]);
lop->data[iop] = make_shared<OpElement<S2, FL>>(
p->name, p->site_index, tr(p->q_label), p->factor);
}
for (int iop = 0; iop < rop->data.size(); iop++) {
const auto p = dynamic_pointer_cast<OpElement<S1, FL>>(
mpo->right_operator_names[ii]->data[iop]);
rop->data[iop] = make_shared<OpElement<S2, FL>>(
p->name, p->site_index, tr(p->q_label), p->factor);
}
rmpo->left_operator_names[ii] = lop;
rmpo->right_operator_names[ii] = rop;
shared_ptr<VectorAllocator<FP>> d_alloc =
make_shared<VectorAllocator<FP>>();
shared_ptr<StateInfo<S1>> conn =
TransStateInfo<S2, S1>::backward_connection(mpo->basis[ii],
rmpo->basis[ii]);
for (const auto &op : mopt->ops) {
const auto p =
dynamic_pointer_cast<OpElement<S1, FL>>(op.first);
auto q = make_shared<OpElement<S2, FL>>(
p->name, p->site_index, tr(p->q_label), p->factor);
shared_ptr<SparseMatrix<S2, FL>> xmat =
make_shared<SparseMatrix<S2, FL>>(d_alloc);
xmat->allocate(hamil->find_site_op_info(ii, q->q_label));
xmat->factor = op.second->factor;
for (int k = 0; k < op.second->info->n; k++) {
S1 plu = op.second->info->quanta[k].get_bra(
op.second->info->delta_quantum);
S1 pru = op.second->info->quanta[k].get_ket();
GMatrix<FL> r = (*op.second)[k];
shared_ptr<StateInfo<S2>> mls =
TransStateInfo<S1, S2>::forward(
make_shared<StateInfo<S1>>(plu), ref);
shared_ptr<StateInfo<S2>> mrs =
TransStateInfo<S1, S2>::forward(
make_shared<StateInfo<S1>>(pru), ref);
for (int iln = 0; iln < mls->n; iln++)
for (int irn = 0; irn < mrs->n; irn++) {
S2 lqn = mls->quanta[iln], rqn = mrs->quanta[irn];
GMatrix<FL> xr =
(*xmat)[q->q_label.combine(lqn, rqn)];
int il = rmpo->basis[ii]->find_state(lqn);
int ir = rmpo->basis[ii]->find_state(rqn);
MKL_INT zl = rmpo->basis[ii]->n_states[il],
zr = rmpo->basis[ii]->n_states[ir];
int klst = conn->n_states[il];
int krst = conn->n_states[ir];
int kled = il == rmpo->basis[ii]->n - 1
? conn->n
: conn->n_states[il + 1];
int kred = ir == rmpo->basis[ii]->n - 1
? conn->n
: conn->n_states[ir + 1];
size_t lsh = 0, rsh = 0;
for (int ilp = klst;
ilp < kled && conn->quanta[ilp] != plu; ilp++)
lsh +=
mpo->basis[ii]
->n_states[mpo->basis[ii]->find_state(
conn->quanta[ilp])];
for (int irp = krst;
irp < kred && conn->quanta[irp] != pru; irp++)
rsh +=
mpo->basis[ii]
->n_states[mpo->basis[ii]->find_state(
conn->quanta[irp])];
MKL_INT kl =
(MKL_INT)mpo->basis[ii]
->n_states[mpo->basis[ii]->find_state(plu)];
MKL_INT kr =
(MKL_INT)mpo->basis[ii]
->n_states[mpo->basis[ii]->find_state(pru)];
for (MKL_INT ikl = 0; ikl < kl; ikl++)
for (MKL_INT ikr = 0; ikr < kr; ikr++)
xr(ikl + lsh, ikr + rsh) = r(ikl, ikr);
}
}
opt->ops[q] = xmat;
}
mpo->unload_tensor(ii);
mpo->unload_left_operators(ii);
mpo->unload_right_operators(ii);
rmpo->save_tensor(ii);
rmpo->unload_tensor(ii);
rmpo->save_left_operators(ii);
rmpo->unload_left_operators(ii);
rmpo->save_right_operators(ii);
rmpo->unload_right_operators(ii);
}
return rmpo;
}
};

} // namespace block2
2 changes: 2 additions & 0 deletions src/instantiation/block2_dmrg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ extern template struct block2::MPO<block2::SAny, double>;
extern template struct block2::DiagonalMPO<block2::SAny, double>;
extern template struct block2::AncillaMPO<block2::SAny, double>;
extern template struct block2::IdentityAddedMPO<block2::SAny, double>;
extern template struct block2::TransMPO<block2::SAny, block2::SAny, double>;

// mpo_fusing.hpp
extern template struct block2::StackedMPO<block2::SAny, double>;
Expand Down Expand Up @@ -1492,6 +1493,7 @@ extern template struct block2::MPO<block2::SAny, complex<double>>;
extern template struct block2::DiagonalMPO<block2::SAny, complex<double>>;
extern template struct block2::AncillaMPO<block2::SAny, complex<double>>;
extern template struct block2::IdentityAddedMPO<block2::SAny, complex<double>>;
extern template struct block2::TransMPO<block2::SAny, block2::SAny, complex<double>>;

// mpo_fusing.hpp
extern template struct block2::StackedMPO<block2::SAny, complex<double>>;
Expand Down
1 change: 1 addition & 0 deletions src/instantiation/dmrg_a/mpo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ template struct block2::MPO<block2::SAny, double>;
template struct block2::DiagonalMPO<block2::SAny, double>;
template struct block2::AncillaMPO<block2::SAny, double>;
template struct block2::IdentityAddedMPO<block2::SAny, double>;
template struct block2::TransMPO<block2::SAny, block2::SAny, double>;
1 change: 1 addition & 0 deletions src/instantiation/dmrg_az/mpo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ template struct block2::MPO<block2::SAny, complex<double>>;
template struct block2::DiagonalMPO<block2::SAny, complex<double>>;
template struct block2::AncillaMPO<block2::SAny, complex<double>>;
template struct block2::IdentityAddedMPO<block2::SAny, complex<double>>;
template struct block2::TransMPO<block2::SAny, block2::SAny, complex<double>>;
5 changes: 4 additions & 1 deletion src/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,12 @@ PYBIND11_MODULE(block2, m) {
bind_trans_mps<SAny, SAny>(m_sany, "sany");
bind_trans_multi_mps<SAny, SAny>(m_sany, "sany");
bind_fl_trans_mps_spin_specific<SAny, SAny, double>(m_sany, "sany");
bind_fl_trans_mpo<SAny, SAny, double>(m_sany, "sany");
#ifdef _USE_COMPLEX
bind_dmrg<SAny, complex<double>>(m_sany_cpx, "SAny");
bind_fl_trans_mps_spin_specific<SAny, SAny, complex<double>>(m_sany_cpx, "sany");
bind_fl_trans_mps_spin_specific<SAny, SAny, complex<double>>(m_sany_cpx,
"sany");
bind_fl_trans_mpo<SAny, SAny, complex<double>>(m_sany_cpx, "sany");
#endif
#endif

Expand Down
3 changes: 3 additions & 0 deletions src/pybind/dmrg_a/trans_mps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@ template auto
bind_fl_trans_mps_spin_specific<SAny, SAny, double>(py::module &m,
const string &aux_name)
-> decltype(typename SAny::is_sany_t(typename SAny::is_sany_t()));
template auto bind_fl_trans_mpo<SAny, SAny, double>(py::module &m,
const string &aux_name)
-> decltype(typename SAny::is_sany_t(typename SAny::is_sany_t()));
4 changes: 4 additions & 0 deletions src/pybind/dmrg_az/trans_mps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@
template auto bind_fl_trans_mps_spin_specific<SAny, SAny, complex<double>>(
py::module &m, const string &aux_name)
-> decltype(typename SAny::is_sany_t(typename SAny::is_sany_t()));
template auto
bind_fl_trans_mpo<SAny, SAny, complex<double>>(py::module &m,
const string &aux_name)
-> decltype(typename SAny::is_sany_t(typename SAny::is_sany_t()));
14 changes: 14 additions & 0 deletions src/pybind/pybind_dmrg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2265,6 +2265,13 @@ void bind_fl_trans_mps_spin_specific(py::module &m, const string &aux_name) {
&TransUnfusedMPS<S, T, FL>::forward);
}

template <typename S, typename T, typename FL>
void bind_fl_trans_mpo(py::module &m, const string &aux_name) {

m.def(("trans_mpo_to_" + aux_name).c_str(), &TransMPO<S, T, FL>::forward,
py::arg("mpo"), py::arg("hamil"), py::arg("tag") = "");
}

template <typename S = void> void bind_dmrg_types(py::module &m) {

py::enum_<TruncationTypes>(m, "TruncationTypes", py::arithmetic())
Expand Down Expand Up @@ -2928,6 +2935,9 @@ extern template auto
bind_fl_trans_mps_spin_specific<SAny, SAny, double>(py::module &m,
const string &aux_name)
-> decltype(typename SAny::is_sany_t(typename SAny::is_sany_t()));
extern template auto
bind_fl_trans_mpo<SAny, SAny, double>(py::module &m, const string &aux_name)
-> decltype(typename SAny::is_sany_t(typename SAny::is_sany_t()));

#ifdef _USE_COMPLEX
extern template void bind_fl_general<SAny, complex<double>>(py::module &m);
Expand Down Expand Up @@ -2962,6 +2972,10 @@ extern template auto
bind_fl_trans_mps_spin_specific<SAny, SAny, complex<double>>(
py::module &m, const string &aux_name)
-> decltype(typename SAny::is_sany_t(typename SAny::is_sany_t()));
extern template auto
bind_fl_trans_mpo<SAny, SAny, complex<double>>(py::module &m,
const string &aux_name)
-> decltype(typename SAny::is_sany_t(typename SAny::is_sany_t()));
#endif

#endif
Expand Down

0 comments on commit e9578ca

Please sign in to comment.