Skip to content

Commit

Permalink
Enable vectorization and use num_jac now.
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasBreuling committed Jun 11, 2024
1 parent 77e97a9 commit 76093c4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 68 deletions.
67 changes: 11 additions & 56 deletions pydaes/integrate/_dae/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import numpy as np
from scipy.sparse import issparse, csc_matrix, bmat
from scipy.sparse import issparse, csc_matrix
from scipy.optimize._numdiff import group_columns
from scipy.integrate._ivp.common import (
validate_max_step, validate_tol, num_jac,
validate_first_step,
)
from scipy.integrate._ivp.base import ConstantDenseOutput
from scipy.optimize._numdiff import approx_derivative
from scipy.linalg import lu_factor, lu_solve
from scipy.sparse import csc_matrix, issparse, eye
from scipy.sparse.linalg import splu
Expand Down Expand Up @@ -197,28 +196,11 @@ def fun_single(t, y, yp):
else:
fun_single = self._fun

def fun_vectorized_y(t, y, yp):
def fun_vectorized(t, y, yp):
f = np.empty_like(y)
for i, yi in enumerate(y.T):
f[:, i] = self._fun(t, yi, yp)
for i, (yi, ypi) in enumerate(zip(y.T, yp.T)):
f[:, i] = self._fun(t, yi, ypi)
return f

def fun_vectorized_yp(t, y, yp):
f = np.empty_like(yp)
for i, ypi in enumerate(yp.T):
f[:, i] = self._fun(t, y, ypi)
return f

# def fun_vectorized(t, y, yp):
# y = np.atleast_2d(y)
# yp = np.atleast_2d(yp)
# ny, my = y.shape
# nyp, myp = yp.shape
# n, m = max(ny, nyp), max(my, myp)
# f = np.empty((n, m))
# for i, (yi, ypi) in enumerate(zip(y.T, yp.T)):
# f[:, i] = self._fun(t, yi, ypi)
# return f

# composite function with z = (y, yp) for finite differences
def fun_composite(t, z):
Expand All @@ -231,16 +213,13 @@ def fun(t, y, yp):

self.fun = fun
self.fun_single = fun_single
# self.fun_vectorized = fun_vectorized
self.fun_vectorized_y = fun_vectorized_y
self.fun_vectorized_yp = fun_vectorized_yp
self.fun_vectorized = fun_vectorized
self.fun_composite = fun_composite
self.f = self.fun(self.t, self.y, self.yp)

self.direction = np.sign(t_bound - t0) if t_bound != t0 else 1
self.status = 'running'

# TODO: What is this factor?
self.jac_factor_y = None
self.jac_factor_yp = None
self.jac, self.Jy, self.Jyp = self._validate_jac(jac, jac_sparsity)
Expand Down Expand Up @@ -276,7 +255,6 @@ def _validate_jac(self, jac, sparsity):

if jac is None:
if sparsity is not None:
# raise NotImplementedError("Exploiting sparsity structure is not supported yet.")
sparsity_y, sparsity_yp = sparsity
if issparse(sparsity_y):
sparsity_y = csc_matrix(sparsity_y)
Expand All @@ -286,40 +264,17 @@ def _validate_jac(self, jac, sparsity):
groups_yp = group_columns(sparsity_yp)
sparsity_y = (sparsity_y, groups_y)
sparsity_yp = (sparsity_yp, groups_yp)
sparsity = bmat([[sparsity_y], [sparsity_yp]])
else:
sparsity_y, sparsity_yp = None, None

def jac_wrapped(t, y, yp, f):
self.njev += 1
# J, self.jac_factor = num_jac(self.fun_vectorized, t, y, f,
# self.atol, self.jac_factor,
# sparsity)
Jy, self.jac_factor_y = num_jac(lambda t, y: self.fun_vectorized_y(t, y, yp),
t, y, f, self.atol, self.jac_factor_y,
sparsity_y)
Jyp, self.jac_factor_yp = num_jac(lambda t, yp: self.fun_vectorized_yp(t, y, yp),
t, yp, f, self.atol, self.jac_factor_yp,
sparsity_y)



# z = np.concatenate((y, yp))
# J = approx_derivative(lambda z: self.fun_composite(t, z),
# z, method="2-point", f0=f,
# sparsity=sparsity)
# J = J.reshape((self.n, 2 * self.n))
# Jy, Jyp = J[:, :self.n], J[:, self.n:]




# Jy, self.jac_factor_y = num_jac(
# lambda t, y: self.fun_vectorized(t, y, yp),
# t, y, f, self.atol, self.jac_factor_y, sparsity_y)
# Jyp, self.jac_factor_yp = num_jac(
# lambda t, yp: self.fun_vectorized(t, y, yp),
# t, yp, f, self.atol, self.jac_factor_yp, sparsity_yp)
Jy, self.jac_factor_y = num_jac(
lambda t, y: self.fun_vectorized(t, y, np.tile(yp[:, None], self.n)),
t, y, f, self.atol, self.jac_factor_y, sparsity_y)
Jyp, self.jac_factor_yp = num_jac(
lambda t, yp: self.fun_vectorized(t, np.tile(y[:, None], self.n), yp),
t, yp, f, self.atol, self.jac_factor_yp, sparsity_y)

return Jy, Jyp

Expand Down
20 changes: 8 additions & 12 deletions test/test_dae.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def F_rational(t, y, yp):


def F_rational_vectorized(t, y, yp):
return yp[:, None] - fun_rational_vectorized(t, y)
return yp - fun_rational_vectorized(t, y)


def J_rational(t, y, yp):
Expand Down Expand Up @@ -70,8 +70,7 @@ def J_complex_sparse(t, y, yp):

parameters_linear = product(
["BDF"], # method
# [None, J_linear, J_linear_sparse] # jac
[None] # jac
[None, J_linear, J_linear_sparse] # jac
)
@pytest.mark.parametrize("method, jac", parameters_linear)
def test_integration_const_jac(method, jac):
Expand All @@ -95,7 +94,6 @@ def test_integration_const_jac(method, jac):
assert_(np.all(e < 5))


# TODO: Get finite differences working with complex numbers.
parameters_complex = product(
["BDF"], # method
[None, J_complex, J_complex_sparse] # jac
Expand All @@ -106,8 +104,6 @@ def test_integration_complex(method, jac):
atol = 1e-6
y0 = np.array([0.5 + 1j])
yp0 = fun_complex(0, y0)
# print(F_complex(0, y0, yp0))
# exit()
t_span = [0, 1]
tc = np.linspace(t_span[0], t_span[1])

Expand Down Expand Up @@ -141,10 +137,8 @@ def test_integration_complex(method, jac):
assert np.all(e < 5)


# TODO: Vectorization is not supported yet!
parameters_rational = product(
[False], # vectorized
# [False, True], # vectorized
[False, True], # vectorized
["BDF"], # method
[[5, 9], [5, 1]], # t_span
[None, J_rational, J_rational_sparse] # jac
Expand All @@ -155,16 +149,15 @@ def test_integration_rational(vectorized, method, t_span, jac):
atol = 1e-6
y0 = [1/3, 2/9]
yp0 = fun_rational(5, y0)
# print(F_rational(5, y0, yp0))
# exit()

if vectorized:
fun = F_rational_vectorized
else:
fun = F_rational

res = solve_dae(fun, t_span, y0, yp0, rtol=rtol, atol=atol,
method=method, dense_output=True, jac=jac)
method=method, dense_output=True, jac=jac,
vectorized=vectorized)

assert_equal(res.t[0], t_span[0])
assert_(res.t_events is None)
Expand Down Expand Up @@ -245,3 +238,6 @@ def F_robertson(t, state, statep):

for params in parameters_rational:
test_integration_rational(*params)

for params in parameters_stiff:
test_integration_stiff(*params)

0 comments on commit 76093c4

Please sign in to comment.