Skip to content

Commit

Permalink
core: rescale const
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Jul 18, 2024
1 parent bb1f7da commit 1e3a1f5
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
41 changes: 41 additions & 0 deletions pyblock2/driver/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,7 @@ def __init__(
self.pg = "c1"
self.orb_sym = None
self.ghamil = None
self.n_elec = None
self._dmrg = None
self._sweep_wfn_spectra = None

Expand Down Expand Up @@ -874,6 +875,7 @@ def initialize_system(
pg_irrep = self.pg_irrep
else:
pg_irrep = 0
self.n_elec = n_elec

if target is None and bw.qargs is not None:
if bw.qargs == ("U1Fermi", "AbelianPG"):
Expand Down Expand Up @@ -3195,6 +3197,7 @@ def get_qc_mpo(
normal_order_ref=None,
normal_order_single_ref=None,
normal_order_wick=True,
rescale=None,
symmetrize=True,
sum_mpo_mod=-1,
compute_accurate_svd_error=True,
Expand Down Expand Up @@ -3319,6 +3322,14 @@ def get_qc_mpo(
Only have effect if ``normal_order_ref is not None``.
If True, will use ``WickNormalOrder`` implementation (via automatic symbolic derivation).
Otherwise, will use the manual implementation. Default is True.
rescale : None or float or True
If None, will not rescale (default).
If zero or True, will adjust ``h1e`` and the const energy so that
the average diagonal element of ``h1e`` is zero.
If non-zero float, will adjust ``h1e`` and the const energy so that
the const energy becomes the given ``rescale`` number.
After rescale, the integrals will only be correct for the given
``n_elec``.
symmetrize : bool
Only have effect if ``self.orb_sym is not None`` (when point group symmetry is used).
If True, will symmetrize integrals so that integral elements violating point group restrictions
Expand Down Expand Up @@ -3439,6 +3450,29 @@ def get_qc_mpo(
x_orb_sym, h1e=h1e, g2e=g2e, k_symm=k_symm, iprint=iprint
)

if rescale is not None:
assert h1e is not None
assert self.n_elec is not None
if iprint >= 1:
print("original const = ", ecore)
if SymmetryTypes.SZ in bw.symm_type:
xn = len(h1e[0]) + len(h1e[1])
x = np.trace(h1e[0]) + np.trace(h1e[1])
else:
xn, x = len(h1e), np.trace(h1e)
if isinstance(rescale, int) and rescale == 0:
x = x / xn
else:
x = (rescale - ecore) / self.n_elec
if SymmetryTypes.SZ in bw.symm_type:
h1e[0][np.mgrid[:len(h1e[0])], np.mgrid[:len(h1e[0])]] -= x
h1e[1][np.mgrid[:len(h1e[1])], np.mgrid[:len(h1e[1])]] -= x
else:
h1e[np.mgrid[:len(h1e)], np.mgrid[:len(h1e)]] -= x
ecore += x * self.n_elec
if iprint >= 1:
print("rescaled const = ", ecore)

if integral_cutoff != 0:
error = 0
if SymmetryTypes.SZ in bw.symm_type:
Expand Down Expand Up @@ -3536,6 +3570,10 @@ def get_qc_mpo(
self.ghamil = bw.bs.GeneralHamiltonian(
self.vacuum, self.n_sites, self.orb_sym, self.heis_twos
)
if normal_order_ref is not None:
normal_order_ref = np.array(normal_order_ref)[idx]
if normal_order_single_ref is not None:
normal_order_single_ref = np.array(normal_order_single_ref)[idx]
else:
self.reorder_idx = None

Expand Down Expand Up @@ -7157,6 +7195,7 @@ def get_random_mps(
mps : MPS
The output MPS (normalized).
"""
import numpy as np
bw = self.bw
if target is None:
target = self.target
Expand Down Expand Up @@ -7201,6 +7240,8 @@ def get_random_mps(
else:
mps_info.set_bond_dimension_fci(left_vacuum, self.vacuum)
if occs is not None:
if self.reorder_idx is not None:
occs = np.array(occs)[self.reorder_idx]
mps_info.set_bond_dimension_using_occ(bond_dim, bw.b.VectorDouble(occs))
else:
mps_info.set_bond_dimension(bond_dim)
Expand Down
4 changes: 2 additions & 2 deletions src/dmrg/sweep_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ template <typename S, typename FL, typename FLS> struct DMRG {
false);
int mmps = 0;
FPS error = 0.0;
tuple<FPS, int, size_t, double> pdi;
tuple<FPLS, int, size_t, double> pdi;
shared_ptr<SparseMatrixGroup<S, FLS>> pket = nullptr,
context_pket = nullptr;
shared_ptr<SparseMatrix<S, FLS>> pdm = nullptr;
Expand Down Expand Up @@ -839,7 +839,7 @@ template <typename S, typename FL, typename FLS> struct DMRG {
}
int mmps = 0;
FPS error = 0.0;
tuple<FPS, int, size_t, double> pdi;
tuple<FPLS, int, size_t, double> pdi;
shared_ptr<SparseMatrixGroup<S, FLS>> pket = nullptr,
context_pket = nullptr;
shared_ptr<SparseMatrix<S, FLS>> context_old_ket = nullptr;
Expand Down
2 changes: 1 addition & 1 deletion tests/cr2-gs/SVP
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


BASIS "ao basis" PRINT
#BASIS SET: (14s,9p,5d) -> [5s,3p,2d]
#BASIS SET: (14s,9p,5d) -> [5s,2p,2d]
Cr S
51528.0863490 0.14405823106E-02
7737.2103487 0.11036202287E-01
Expand Down

0 comments on commit 1e3a1f5

Please sign in to comment.