Skip to content

Commit

Permalink
eig solver for root finder
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Sep 17, 2023
1 parent 6692099 commit a88d65e
Show file tree
Hide file tree
Showing 7 changed files with 364 additions and 148 deletions.
13 changes: 11 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ configure_file(
set(cintSrc
src/c2f.c src/cart2sph.c src/cint1e.c src/cint2e.c src/cint_bas.c
src/fblas.c src/g1e.c src/g2e.c src/misc.c src/optimizer.c
src/fmt.c src/rys_wheeler.c src/eigh.c src/rys_roots.c
src/fmt.c src/rys_wheeler.c src/eigh.c src/rys_roots.c src/find_roots.c
src/cint2c2e.c src/g2c2e.c src/cint3c2e.c src/g3c2e.c
src/cint3c1e.c src/g3c1e.c src/breit.c
src/cint1e_a.c src/cint3c1e_a.c
Expand All @@ -90,9 +90,18 @@ if(WITH_RANGE_COULOMB)
# defined in config.h
# add_definitions(-DWITH_RANGE_COULOMB)
message("Enabled WITH_RANGE_COULOMB")
set(cintSrc ${cintSrc} src/polyfits.c src/sr_rys_polyfits.c)
if(WITH_POLYNOMIAL_FIT)
set(cintSrc ${cintSrc} src/polyfits.c src/sr_rys_polyfits.c)
add_definitions(-DWITH_POLYNOMIAL_FIT)
message("Enabled WITH_POLYNOMIAL_FIT")
endif(WITH_POLYNOMIAL_FIT)
endif(WITH_RANGE_COULOMB)

# Optional switch: when WITH_EIG_ROOTFINDER is set, pass -DWITH_EIG_ROOTFINDER
# to the compiler (directory scope) and report the choice at configure time.
if(WITH_EIG_ROOTFINDER)
  add_definitions(-DWITH_EIG_ROOTFINDER)
  message("Enabled linear equation solver with eigenvalue algorithm")
endif()

if(WITH_COULOMB_ERF)
set(cintSrc ${cintSrc} src/g2e_coulerf.c src/cint2e_coulerf.c)
add_definitions(-DWITH_COULOMB_ERF)
Expand Down
7 changes: 4 additions & 3 deletions scripts/find_polyroots.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def find_polyroots(cs, nroots):
for m in range(nroots-1):
A[m+1,m] = mpmath.mpf(1)
for m in range(nroots):
A[0,m] = -cs[nroots,nroots-1-m] / cs[nroots,nroots]
A[0,nroots-1-m] = -cs[nroots,m] / cs[nroots,nroots]
roots = eig(A)
return np.array(roots[::-1])

Expand Down Expand Up @@ -230,8 +230,9 @@ def hessenberg_qr(A, eps, maxits=120):

while k + 1 < n1:
s = abs(A[k,k]) + abs(A[k+1,k+1])
if s < eps:
s = 1
#if s < eps:
# s = 1
# Ensure relative error converged
if abs(A[k+1,k]) < eps * s:
break
k += 1
Expand Down
7 changes: 5 additions & 2 deletions scripts/rys_tabulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def clenshaw_points(n):
ngrids = 14
chebrt = np.array(chebyshev_roots(ngrids))
cs = np.array(clenshaw_points(ngrids))
TBASE = np.arange(0, 51, 2.5)
TBASE = np.append(np.arange(0, 39, 2.5), np.arange(40, 104, 4))
print(TBASE)

def get_cheb_t_points(tbase):
Expand Down Expand Up @@ -93,14 +93,17 @@ def polyfit_erf(nroots, x):
def pkl2table(prefix, pklfile):
with open(pklfile, 'rb') as f:
TBASE, rys_tab = pickle.load(f)
nt = find_tbase(81)
TBASE = TBASE.round(6)
TBASE = TBASE[:nt+1]
with open(f'{prefix}_x.dat', 'w') as fx, open(f'{prefix}_w.dat', 'w') as fw:
fw.write(f'// DATA_TBASE[{len(TBASE)}] = ''{' + (', '.join([str(x) for x in TBASE])) + '};\n')
fx.write(f'static double DATA_X[] = ''{\n')
fw.write(f'static double DATA_W[] = ''{\n')
for i, tab in enumerate(rys_tab):
nroots = i + 6
for it, ttab in enumerate(tab):
for it in range(nt):
ttab = tab[it]
tbase = TBASE[it]
print(f'root {nroots} tbase[{it}] {tbase}')
fx.write(f'/* root={nroots} base[{it}]={tbase} */\n')
Expand Down
6 changes: 4 additions & 2 deletions scripts/sr_rys_tabulate_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,22 @@
im = clenshaw_d1(tab_ws.astype(float), uu, nroots)
ww = np.array(clenshaw_d1(im, tt, nroots), dtype=float)
ref = np.array(rys_roots_weights(nroots, x, low), dtype=float)
rr /= rr + 1
ref[0] /= ref[0] + 1
diff1, diff2 = abs((rr - ref[0])/ref[0]).max(), abs((ww - ref[1])).max()
if diff1 > 1e-12 or diff2 > 1e-13:
print(nroots, x, low, diff1, diff2)

np.random.seed(1)
xs = np.sort(np.append(np.random.rand(100) * 50, 3**((np.random.rand(20)-.3) * 5)))
xs = np.sort(np.append(np.random.rand(100) * 100, 3**((np.random.rand(30)-.4) * 6)))
xs[0] = 1e-15
fil = 'rys_rw.pkl'
TBASE, tab = pickle.load(open(fil, 'rb'))
rys_tabulate.TBASE = TBASE

for nroots in range(1, 15):
for x in xs:
if x >= 50:
if x >= 100:
continue

it = np.searchsorted(rys_tabulate.TBASE, x) - 1
Expand Down
294 changes: 294 additions & 0 deletions src/find_roots.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
/*
Adapted from the mpmath.matrices.eigen.eig function.
This implementation is restricted to real eigenvalues.
*/

#include <stdio.h>
#include <math.h>

#define SQUARE(X) ((X) * (X))
#define MIN(X,Y) ((X)<(Y)?(X):(Y))
#define MXRYSROOTS 32

/*
 * Evaluate coef[0] + coef[1]*x + ... + coef[order]*x^order with Horner's
 * scheme, starting from the leading coefficient coef[order].
 */
static double _polynomial_value(const double *coef, int order, double x)
{
        double value = coef[order];
        int k;
        for (k = order - 1; k >= 0; k--) {
                value = value * x + coef[k];
        }
        return value;
}

/*
 * Refine the roots of the polynomial a[0] + a[1]*x + ... + a[order]*x^order.
 *
 * On entry roots[0..order-1] supply the bracket end points: root m is
 * searched between the previous end point (0 for m == 0) and roots[m].
 * On exit roots[m] holds the refined root.  If the polynomial evaluates to
 * exactly 0 at roots[m], that entry is left untouched.
 *
 * Returns 0 on success.  Returns 1 (after writing a message to stderr) when
 * a bracket shows no sign change or when more than 200 refinement passes
 * are needed.
 */
static int R_dnode(double *a, double *roots, int order)
{
        const double accrt = 1e-15;
        /* right end point (and its polynomial value) of the previous bracket */
        double edge_x = 0;
        double edge_p = a[0];
        double lo, hi, flo, fhi, xi, fxi;
        int m, niter;

        for (m = 0; m < order; ++m) {
                lo = edge_x;
                flo = edge_p;
                edge_x = roots[m];
                edge_p = _polynomial_value(a, order, edge_x);

                // When all coefficients a are 0, skip the search so that the
                // roots from the lower order polynomials are preserved
                if (edge_p == 0) {
                        continue;
                }
                if (flo * edge_p > 0) {
                        fprintf(stderr, "ROOT NUMBER %d WAS NOT FOUND FOR POLYNOMIAL OF ORDER %d\n",
                                m, order);
                        return 1;
                }

                /* order the bracket so that lo <= hi */
                if (lo <= edge_x) {
                        hi = edge_x;
                        fhi = edge_p;
                } else {
                        hi = lo;
                        fhi = flo;
                        lo = edge_x;
                        flo = edge_p;
                }

                /* an end point that is already an exact root needs no search */
                if (fhi == 0) {
                        roots[m] = hi;
                        continue;
                }
                if (flo == 0) {
                        roots[m] = lo;
                        continue;
                }

                /* linear interpolation between the two end points */
                xi = lo + (lo - hi) / (fhi - flo) * flo;
                for (niter = 1; fabs(hi - lo) > hi * accrt; niter++) {
                        if (niter > 200) {
                                fprintf(stderr, "libcint::rys_roots NO CONV. IN R_dnode\n");
                                return 1;
                        }

                        /* first probe: the interpolated point; the follow-up
                         * point is pulled a quarter towards the end point
                         * that was not replaced */
                        fxi = _polynomial_value(a, order, xi);
                        if (fxi == 0) {
                                break;
                        }
                        if (flo * fxi <= 0) {
                                hi = xi;
                                fhi = fxi;
                                xi = lo * .25 + xi * .75;
                        } else {
                                lo = xi;
                                flo = fxi;
                                xi = xi * .75 + hi * .25;
                        }

                        /* second probe: shrink the bracket once more */
                        fxi = _polynomial_value(a, order, xi);
                        if (fxi == 0) {
                                break;
                        }
                        if (flo * fxi <= 0) {
                                hi = xi;
                                fhi = fxi;
                        } else {
                                lo = xi;
                                flo = fxi;
                        }

                        xi = lo + (lo - hi) / (fhi - flo) * flo;
                }
                roots[m] = xi;
        }
        return 0;
}

/*
 * Scale the pair (*c, *s) to unit Euclidean length.  A zero pair is replaced
 * by the identity rotation (c = 1, s = 0).  Returns the length measured
 * before scaling (0 for a zero pair).
 */
static double _givens_normalize(double *c, double *s)
{
        double len = sqrt((*c) * (*c) + (*s) * (*s));
        double scale = len;

        if (len == 0) {
                scale = 1;
                *c = 1;
                *s = 0;
        }
        scale = 1. / scale;
        *c *= scale;
        *s *= scale;
        return len;
}

/* Apply the rotation (c, s) from the left to columns [kbeg, kend) of two rows. */
static void _rotate_rows(double *upper, double *lower, int kbeg, int kend,
                         double c, double s)
{
        double x, y;
        int k;
        for (k = kbeg; k < kend; k++) {
                x = upper[k];
                y = lower[k];
                upper[k] = c * x + s * y;
                lower[k] = c * y - s * x;
        }
}

/* Apply the rotation (c, s) from the right to rows [0, nrows) of two columns. */
static void _rotate_cols(double *A, int ld, int nrows, int left, int right,
                         double c, double s)
{
        double x, y;
        int k;
        for (k = 0; k < nrows; k++) {
                x = A[k*ld+left];
                y = A[k*ld+right];
                A[k*ld+left] = c * x + s * y;
                A[k*ld+right] = c * y - s * x;
        }
}

/*
 * One shifted QR step, carried out with Givens rotations, on the active
 * block [n0, n1) of the row-major nroots x nroots matrix A (updated in
 * place).  The first rotation is built from A[n0,n0] - shift and A[n0+1,n0];
 * each following rotation zeroes the entry A[j+2,j] for j = n0 .. n1-3.
 */
static void _qr_step(double *A, int nroots, int n0, int n1, double shift)
{
        int below = n0 + 1;
        double c = A[n0*nroots+n0] - shift;
        double s = A[below*nroots+n0];
        int j;

        _givens_normalize(&c, &s);
        _rotate_rows(A + n0*nroots, A + below*nroots, n0, nroots, c, s);
        _rotate_cols(A, nroots, MIN(n1, n0+3), n0, below, c, s);

        for (j = n0; j < n1 - 2; j++) {
                double *row1 = A + (j+1)*nroots;
                double *row2 = A + (j+2)*nroots;

                // rotation built from the two entries below the diagonal in
                // column j; the rotated column keeps only its length
                c = row1[j];
                s = row2[j];
                row1[j] = _givens_normalize(&c, &s);
                row2[j] = 0;

                _rotate_rows(row1, row2, j+1, nroots, c, s);
                _rotate_cols(A, nroots, MIN(n1, j+4), j+1, j+2, c, s);
        }
}

/*
 * Shifted QR iteration on the row-major nroots x nroots matrix A (updated in
 * place; assumed to be upper Hessenberg, as the function name indicates).
 * Only real shifts are used.
 *
 * Returns 0 once every sub-diagonal entry has been deflated; the caller then
 * reads the eigenvalues from the diagonal of A.  Returns 1, after writing a
 * message to stderr, if the leading 2x2 block has no real eigenvalue pair,
 * if more than 30 consecutive steps pass without a deflation, or if the
 * overall budget of nroots*30 sweeps is exhausted.
 */
static int _hessenberg_qr(double *A, int nroots)
{
        const double eps = 1e-15;
        const int maxits = 30;
        int lo = 0;        /* active block is A[lo:hi, lo:hi] */
        int hi = nroots;
        int nstall = 0;    /* QR steps since the last deflation */
        int sweep, pos;

        for (sweep = 0; sweep < nroots*maxits; sweep++) {
                /* look for a sub-diagonal entry that is negligible relative
                 * to its two diagonal neighbours */
                for (pos = lo; pos + 1 < hi; pos++) {
                        double scale = fabs(A[pos*nroots+pos])
                                     + fabs(A[(pos+1)*nroots+pos+1]);
                        if (fabs(A[(pos+1)*nroots+pos]) < eps * scale) {
                                break;
                        }
                }

                if (pos + 1 < hi) {
                        // deflation found at position (pos+1, pos)
                        A[(pos+1)*nroots+pos] = 0;
                        lo = pos + 1;
                        nstall = 0;
                        if (lo + 1 >= hi) {
                                // trailing block of size one has converged;
                                // restart on the leading part
                                hi = lo;
                                lo = 0;
                                if (hi < 2) {
                                        // QR algorithm has converged
                                        return 0;
                                }
                        }
                        continue;
                }

                /* no deflation: choose a shift from the trailing 2x2 block */
                double last = A[(hi-1)*nroots+(hi-1)];
                double prev = A[(hi-2)*nroots+(hi-2)];
                double trace = last + prev;
                double gap = last - prev;
                double disc = gap * gap;
                double shift;
                disc += 4 * A[(hi-1)*nroots+(hi-2)] * A[(hi-2)*nroots+(hi-1)];

                if (disc > 0) {
                        /* two real eigenvalues: take the one closer to the
                         * last diagonal entry */
                        double root = sqrt(disc);
                        double ev_plus = (trace + root) * .5;
                        double ev_minus = (trace - root) * .5;
                        if (fabs(last - ev_plus) > fabs(last - ev_minus)) {
                                shift = ev_minus;
                        } else {
                                shift = ev_plus;
                        }
                } else if (hi == 2) {
                        fprintf(stderr, "hessenberg_qr: failed to find real roots\n");
                        return 1;
                } else {
                        shift = trace * .5;
                }

                nstall += 1;
                _qr_step(A, nroots, lo, hi, shift);
                if (nstall > maxits) {
                        fprintf(stderr, "hessenberg_qr: failed to converge after %d steps\n", nstall);
                        return 1;
                }
        }
        fprintf(stderr, "hessenberg_qr failed\n");
        return 1;
}

/*
 * Compute the roots of the degree-nroots polynomial stored in cs.
 *
 * cs is read as a row-major table with row stride nroots+1.  Row k holds the
 * coefficients of the degree-k polynomial in ascending powers, i.e.
 * cs[k*(nroots+1)+j] multiplies x^j.
 *
 * nroots == 1 and nroots == 2 are solved in closed form.  For larger nroots
 * the companion matrix of row nroots is diagonalised with _hessenberg_qr.
 * If that fails, the roots are rebuilt order by order with R_dnode, starting
 * from the closed-form roots of the quadratic in row 2 and using 1 as the
 * initial upper end point for every further root.
 *
 * roots receives nroots values.  Returns 0 on success, non-zero on failure.
 */
int _CINT_polynomial_roots(double *roots, double *cs, int nroots)
{
        if (nroots < 1) {
                // Nothing can be solved; without this guard the fallback
                // below would write roots[0] and roots[1] out of bounds.
                fprintf(stderr, "_CINT_polynomial_roots: invalid nroots %d\n", nroots);
                return 1;
        }
        if (nroots == 1) {
                roots[0] = -cs[2] / cs[3];
                return 0;
        } else if (nroots == 2) {
                double dum = sqrt(SQUARE(cs[2*3+1]) - 4*cs[2*3+0]*cs[2*3+2]);
                roots[0] = (-cs[2*3+1] - dum) / cs[2*3+2] / 2;
                roots[1] = (-cs[2*3+1] + dum) / cs[2*3+2] / 2;
                return 0;
        }

        double A[MXRYSROOTS * MXRYSROOTS];
        int nroots1 = nroots + 1;
        int i;
        // err stays non-zero when the eigenvalue solver is skipped, which
        // routes the request to the R_dnode fallback below.
        int err = 1;

        // The companion matrix only fits in A for nroots <= MXRYSROOTS.
        // Larger requests skip the eigenvalue solver instead of overflowing
        // the stack buffer.
        if (nroots <= MXRYSROOTS) {
                // First row: negated coefficients divided by the leading
                // one, highest power first.
                double fac = -1. / cs[nroots*nroots1+nroots];
                for (i = 0; i < nroots; i++) {
                        A[nroots-1-i] = cs[nroots*nroots1+i] * fac;
                }
                for (i = nroots; i < nroots*nroots; i++) {
                        A[i] = 0;
                }
                // unit sub-diagonal
                for (i = 0; i < nroots-1; i++) {
                        A[(i+1)*nroots+i] = 1.;
                }
                err = _hessenberg_qr(A, nroots);
        }

        if (err == 0) {
                // eigenvalues sit on the diagonal; store them in reverse order
                for (i = 0; i < nroots; i++) {
                        roots[nroots-1-i] = A[i*nroots+i];
                }
        } else {
                int k, order;
                double *a;
                double dum = sqrt(cs[2*nroots1+1] * cs[2*nroots1+1]
                                  - 4 * cs[2*nroots1+0] * cs[2*nroots1+2]);
                roots[0] = .5 * (-cs[2*nroots1+1] - dum) / cs[2*nroots1+2];
                roots[1] = .5 * (-cs[2*nroots1+1] + dum) / cs[2*nroots1+2];
                // upper end point of the search interval for each new root
                for (i = 2; i < nroots; i++) {
                        roots[i] = 1;
                }
                for (k = 2; k < nroots; ++k) {
                        order = k + 1;
                        a = cs + order * nroots1;
                        err = R_dnode(a, roots, order);
                        if (err) {
                                break;
                        }
                }
        }
        return err;
}
Loading

0 comments on commit a88d65e

Please sign in to comment.