diff --git a/scipy_dae/integrate/_dae/radau.py b/scipy_dae/integrate/_dae/radau.py index 4023f6f..175c3e4 100644 --- a/scipy_dae/integrate/_dae/radau.py +++ b/scipy_dae/integrate/_dae/radau.py @@ -494,7 +494,6 @@ def _step_impl(self): Z0 = np.zeros((s, y.shape[0])) else: Z0 = self.sol(t + h * C)[0].T - y - Z0 = np.zeros((s, y.shape[0])) scale = atol + np.abs(y) * rtol converged = False @@ -701,38 +700,33 @@ def _call_impl(self, t): x = (t - self.t_old) / self.h x = np.atleast_1d(x) - # factors for nterpolation polynomial and its derivative - # p = np.tile(x, (self.order + 1, 1)) - # p = np.cumprod(p, axis=0) + # factors for interpolation polynomial and its derivative c = np.arange(1, self.order + 2)[:, None] p = x**c dp = (c / self.h) * (x**(c - 1)) - # Here we don't multiply by h, not a mistake. + # 1. compute collocation polynomial for y and yp y = np.dot(self.Q, p) yp = np.dot(self.Qp, p) y += self.y_old[:, None] yp += self.yp_old[:, None] - if t.ndim == 0: - y = np.squeeze(y) - yp = np.squeeze(yp) - # compute derivative of interpolation polynomial - yp = np.dot(self.Q, dp) + # # 2. compute derivative of interpolation polynomial for y + # yp = np.dot(self.Q, dp) - # # compute both values by Horner's method: - # # https://cut-the-knot.org/Curriculum/Calculus/HornerMethod.shtml - # # https://math.stackexchange.com/questions/2139142/how-does-horner-method-evaluate-the-derivative-of-a-function + # # 3. compute both values by Horner's rule # y = np.zeros_like(y) - # y += self.Q[:, -1][:, None] # yp = np.zeros_like(y) - # for i in range(self.order + 1, 1, -1): + # for i in range(self.order, -1, -1): + # y = self.Q[:, i][:, None] + y * x[None, :] # yp = y + yp * x[None, :] - # y = self.Q[:, i - 1][:, None] + y * x[None, :] - - # # y += self.y_old[:, None] + # y = self.y_old[:, None] + y * x[None, :] # yp /= self.h + if t.ndim == 0: + y = np.squeeze(y) + yp = np.squeeze(yp) + return y, yp diff --git a/scipy_dae/integrate/_dae/tests/test_dae.py b/scipy_dae/integrate/_dae/tests/test_dae.py index 037b2a8..08693db 100644 --- a/scipy_dae/integrate/_dae/tests/test_dae.py +++ b/scipy_dae/integrate/_dae/tests/test_dae.py @@ -186,7 +186,7 @@ def test_integration_rational(vectorized, method, t_span, jac): e = compute_error(yc, yc_true, rtol, atol) assert_(np.all(e < 5)) - assert_allclose(res.sol(res.t)[0], res.y, rtol=1e-15, atol=1e-15) + assert_allclose(res.sol(res.t)[0], res.y, rtol=1e-14, atol=1e-14) parameters_stiff = ["BDF", "Radau"]