Skip to content

Commit

Permalink
add py.test for csf eval; fix sany csf eval
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Jul 12, 2024
1 parent cf99d8e commit 0708415
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
10 changes: 10 additions & 0 deletions pyblock2/unit_test/dmrg.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ def test_rhf(self, tmp_path, system_def, symm_type):
)
if name == "N2":
assert abs(energy - -107.654122447523) < 1e-6

dets, vals = driver.get_csf_coefficients(ket, cutoff=0.1)
dets, vals = dets[np.argsort(-np.abs(vals))], vals[np.argsort(-np.abs(vals))]
assert abs(abs(vals[0]) - 0.9575065) < 1E-4
assert list(dets[0]) == [3] * 7 + [0] * 3
elif name == "C2":
assert abs(energy - -75.552895292451) < 1e-6

Expand Down Expand Up @@ -179,6 +184,11 @@ def test_uhf(self, tmp_path, system_def, symm_type):
)
if name == "N2":
assert abs(energy - -107.654122447523) < 1e-6

dets, vals = driver.get_csf_coefficients(ket, cutoff=0.1)
dets, vals = dets[np.argsort(-np.abs(vals))], vals[np.argsort(-np.abs(vals))]
assert abs(abs(vals[0]) - 0.9575065) < 1E-4
assert list(dets[0]) == [3] * 7 + [0] * 3
elif name == "C2":
assert abs(energy - -75.552895292345) < 1e-6

Expand Down
6 changes: 4 additions & 2 deletions src/dmrg/determinant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1405,7 +1405,8 @@ struct DeterminantTRIE<S, FL, typename S::is_sany_t>
for (auto &m : mps->tensors[d]->data[jd]) {
S bra = m.first.first, ket = m.first.second;
S jket = bra + mps->info->basis[d]->quanta[jd];
assert(basis_iqs[d][j][1] < jket.count());
if (basis_iqs[d][j][1] >= jket.count())
continue;
if (jket[basis_iqs[d][j][1]] == ket && !qkets.count(ket))
qkets[ket] = m.second->shape[2];
}
Expand Down Expand Up @@ -1466,7 +1467,8 @@ struct DeterminantTRIE<S, FL, typename S::is_sany_t>
for (auto &m : mps->tensors[d]->data[jd]) {
S bra = m.first.first, ket = m.first.second;
S jket = bra + mps->info->basis[d]->quanta[jd];
assert(basis_iqs[d][j][1] < jket.count());
if (basis_iqs[d][j][1] >= jket.count())
continue;
if (jket[basis_iqs[d][j][1]] != ket)
continue;
if (pmp->info->find_state(bra) == -1)
Expand Down

0 comments on commit 0708415

Please sign in to comment.