Skip to content

Commit c5260d3

Browse files
authored
Internal: Remove private select-style functions because we have easier ways of accessing derivatives now (#799)
* Remove self._select from DenseConditional * Simplify select() for isotropic and blockdiag factorisations, too * Delete marginal_nth_derivative because it hasn't been used anywhere
1 parent 2a1b2c1 commit c5260d3

File tree

3 files changed

+28
-101
lines changed

3 files changed

+28
-101
lines changed

probdiffeq/impl/_conditional.py

Lines changed: 21 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,11 @@ def to_derivative(self, i, standard_deviation):
147147

148148

149149
class DenseConditional(ConditionalBackend):
150-
def __init__(self, ode_shape, num_derivatives, unravel):
150+
def __init__(self, ode_shape, num_derivatives, unravel, flat_shape):
151151
self.ode_shape = ode_shape
152152
self.num_derivatives = num_derivatives
153153
self.unravel = unravel
154+
self.flat_shape = flat_shape
154155

155156
def apply(self, x, conditional, /):
156157
matrix, noise = conditional
@@ -178,8 +179,6 @@ def revert(self, rv, conditional, /):
178179
mean, cholesky = rv.mean, rv.cholesky
179180

180181
# QR-decomposition
181-
# (todo: rename revert_conditional_noisefree to
182-
# revert_transformation_cov_sqrt())
183182
r_obs, (r_cor, gain) = cholesky_util.revert_conditional(
184183
R_X_F=(matrix @ cholesky).T, R_X=cholesky.T, R_YX=noise.cholesky.T
185184
)
@@ -208,8 +207,7 @@ def ibm_transitions(self, *, output_scale):
208207
A = np.kron(a, eye_d)
209208
Q = np.kron(q_sqrtm, eye_d)
210209

211-
ndim = d * (self.num_derivatives + 1)
212-
q0 = np.zeros((ndim,))
210+
q0 = np.zeros(self.flat_shape)
213211
noise = _normal.Normal(q0, Q)
214212

215213
precon_fun = preconditioner_prepare(num_derivatives=self.num_derivatives)
@@ -230,34 +228,25 @@ def preconditioner_apply(self, cond, p, p_inv, /):
230228
return Conditional(A, noise)
231229

232230
def to_derivative(self, i, standard_deviation):
233-
a0 = functools.partial(self._select, idx_or_slice=i)
231+
x = np.zeros(self.flat_shape)
232+
233+
def select(a):
234+
return self.unravel(a)[i]
235+
236+
linop = functools.jacrev(select)(x)
234237

235238
(d,) = self.ode_shape
236239
bias = np.zeros((d,))
237240
eye = np.eye(d)
238241
noise = _normal.Normal(bias, standard_deviation * eye)
239-
240-
x = np.zeros(((self.num_derivatives + 1) * d,))
241-
linop = _jac_materialize(lambda s, _p: self._autobatch_linop(a0)(s), inputs=x)
242242
return Conditional(linop, noise)
243243

244-
def _select(self, x, /, idx_or_slice):
245-
return self.unravel(x)[idx_or_slice]
246-
247-
@staticmethod
248-
def _autobatch_linop(fun):
249-
def fun_(x):
250-
if np.ndim(x) > 1:
251-
return functools.vmap(fun_, in_axes=1, out_axes=1)(x)
252-
return fun(x)
253-
254-
return fun_
255-
256244

257245
class IsotropicConditional(ConditionalBackend):
258-
def __init__(self, *, ode_shape, num_derivatives):
246+
def __init__(self, *, ode_shape, num_derivatives, unravel_tree):
259247
self.ode_shape = ode_shape
260248
self.num_derivatives = num_derivatives
249+
self.unravel_tree = unravel_tree
261250

262251
def apply(self, x, conditional, /):
263252
A, noise = conditional
@@ -332,22 +321,24 @@ def preconditioner_apply(self, cond, p, p_inv, /):
332321
return Conditional(A_new, noise)
333322

334323
def to_derivative(self, i, standard_deviation):
335-
def A(x):
336-
return x[[i], ...]
324+
def select(a):
325+
return tree_util.ravel_pytree(self.unravel_tree(a)[i])[0]
326+
327+
m = np.zeros((self.num_derivatives + 1,))
328+
linop = functools.jacrev(select)(m)
337329

338330
bias = np.zeros(self.ode_shape)
339331
eye = np.eye(1)
340332
noise = _normal.Normal(bias, standard_deviation * eye)
341333

342-
m = np.zeros((self.num_derivatives + 1,))
343-
linop = _jac_materialize(lambda s, _p: A(s), inputs=m)
344334
return Conditional(linop, noise)
345335

346336

347337
class BlockDiagConditional(ConditionalBackend):
348-
def __init__(self, *, ode_shape, num_derivatives):
338+
def __init__(self, *, ode_shape, num_derivatives, unravel_tree):
349339
self.ode_shape = ode_shape
350340
self.num_derivatives = num_derivatives
341+
self.unravel_tree = unravel_tree
351342

352343
def apply(self, x, conditional, /):
353344
if np.ndim(x) == 1:
@@ -434,15 +425,11 @@ def preconditioner_apply(self, cond, p, p_inv, /):
434425
return Conditional(A_new, noise)
435426

436427
def to_derivative(self, i, standard_deviation):
437-
def A(x):
438-
return x[[i], ...]
439-
440-
@functools.vmap
441-
def lo(y):
442-
return _jac_materialize(lambda s, _p: A(s), inputs=y)
428+
def select(a):
429+
return tree_util.ravel_pytree(self.unravel_tree(a)[i])[0]
443430

444431
x = np.zeros((*self.ode_shape, self.num_derivatives + 1))
445-
linop = lo(x)
432+
linop = functools.vmap(functools.jacrev(select))(x)
446433

447434
bias = np.zeros((*self.ode_shape, 1))
448435
eye = np.ones((*self.ode_shape, 1, 1)) * np.eye(1)[None, ...]
@@ -494,7 +481,3 @@ def _batch_gram(k, /):
494481

495482
def _binom(n, k):
496483
return np.factorial(n) / (np.factorial(n - k) * np.factorial(k))
497-
498-
499-
def _jac_materialize(func, /, *, inputs, params=None):
500-
return functools.jacrev(lambda v: func(v, params))(inputs)

probdiffeq/impl/_stats.py

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,6 @@ def rescale_cholesky(self, rv, factor, /):
4242
def qoi(self, rv):
4343
raise NotImplementedError
4444

45-
@abc.abstractmethod
46-
def marginal_nth_derivative(self, rv):
47-
raise NotImplementedError
48-
4945
@abc.abstractmethod
5046
def qoi_from_sample(self, sample, /):
5147
raise NotImplementedError
@@ -105,37 +101,11 @@ def to_multivariate_normal(self, rv):
105101
def qoi(self, rv):
106102
return self.qoi_from_sample(rv.mean)
107103

108-
def marginal_nth_derivative(self, rv, i):
109-
if rv.mean.ndim > 1:
110-
return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))(
111-
rv, i
112-
)
113-
114-
m = self._select(rv.mean, i)
115-
c = functools.vmap(self._select, in_axes=(1, None), out_axes=1)(rv.cholesky, i)
116-
c = cholesky_util.triu_via_qr(c.T)
117-
return _normal.Normal(m, c.T)
118-
119104
def qoi_from_sample(self, sample, /):
120105
if np.ndim(sample) > 1:
121106
return functools.vmap(self.qoi_from_sample)(sample)
122107
return self.unravel(sample)
123108

124-
def _select(self, x, /, idx_or_slice):
125-
x_reshaped = np.reshape(x, (-1, *self.ode_shape), order="F")
126-
if isinstance(idx_or_slice, int) and idx_or_slice > x_reshaped.shape[0]:
127-
raise ValueError
128-
return x_reshaped[idx_or_slice]
129-
130-
@staticmethod
131-
def _autobatch_linop(fun):
132-
def fun_(x):
133-
if np.ndim(x) > 1:
134-
return functools.vmap(fun_, in_axes=1, out_axes=1)(x)
135-
return fun(x)
136-
137-
return fun_
138-
139109
def update_mean(self, mean, x, /, num):
140110
nominator = cholesky_util.sqrt_sum_square_scalar(np.sqrt(num) * mean, x)
141111
denominator = np.sqrt(num + 1)
@@ -198,19 +168,6 @@ def to_multivariate_normal(self, rv):
198168
mean = rv.mean.reshape((-1,), order="F")
199169
return (mean, cov)
200170

201-
def marginal_nth_derivative(self, rv, i):
202-
if np.ndim(rv.mean) > 2:
203-
return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))(
204-
rv, i
205-
)
206-
207-
if i > np.shape(rv.mean)[0]:
208-
raise ValueError
209-
210-
mean = rv.mean[i, :]
211-
cholesky = cholesky_util.triu_via_qr(rv.cholesky[[i], :].T).T
212-
return _normal.Normal(mean, cholesky)
213-
214171
def qoi(self, rv):
215172
return self.qoi_from_sample(rv.mean)
216173

@@ -287,22 +244,6 @@ def qoi_from_sample(self, sample, /):
287244
return functools.vmap(self.qoi_from_sample)(sample)
288245
return self.unravel(sample)
289246

290-
def marginal_nth_derivative(self, rv, i):
291-
if np.ndim(rv.mean) > 2:
292-
return functools.vmap(self.marginal_nth_derivative, in_axes=(0, None))(
293-
rv, i
294-
)
295-
296-
if i > np.shape(rv.mean)[0]:
297-
raise ValueError
298-
299-
mean = rv.mean[:, i]
300-
cholesky = functools.vmap(cholesky_util.triu_via_qr)(
301-
(rv.cholesky[:, i, :])[..., None]
302-
)
303-
cholesky = np.transpose(cholesky, axes=(0, 2, 1))
304-
return _normal.Normal(mean, cholesky)
305-
306247
def update_mean(self, mean, x, /, num):
307248
if np.ndim(mean) > 0:
308249
assert np.shape(mean) == np.shape(x)

probdiffeq/impl/impl.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def choose(which: str, /, *, tcoeffs_like) -> FactImpl:
4040

4141
def _select_dense(*, tcoeffs_like) -> FactImpl:
4242
ode_shape = tcoeffs_like[0].shape
43-
_, unravel = tree_util.ravel_pytree(tcoeffs_like)
43+
flat, unravel = tree_util.ravel_pytree(tcoeffs_like)
4444

4545
num_derivatives = len(tcoeffs_like) - 1
4646

@@ -49,7 +49,10 @@ def _select_dense(*, tcoeffs_like) -> FactImpl:
4949
linearise = _linearise.DenseLinearisation(ode_shape=ode_shape, unravel=unravel)
5050
stats = _stats.DenseStats(ode_shape=ode_shape, unravel=unravel)
5151
conditional = _conditional.DenseConditional(
52-
ode_shape=ode_shape, num_derivatives=num_derivatives, unravel=unravel
52+
ode_shape=ode_shape,
53+
num_derivatives=num_derivatives,
54+
unravel=unravel,
55+
flat_shape=flat.shape,
5356
)
5457
transform = _conditional.DenseTransform()
5558
return FactImpl(
@@ -76,7 +79,7 @@ def _select_isotropic(*, tcoeffs_like) -> FactImpl:
7679
stats = _stats.IsotropicStats(ode_shape=ode_shape, unravel=unravel)
7780
linearise = _linearise.IsotropicLinearisation()
7881
conditional = _conditional.IsotropicConditional(
79-
ode_shape=ode_shape, num_derivatives=num_derivatives
82+
ode_shape=ode_shape, num_derivatives=num_derivatives, unravel_tree=unravel_tree
8083
)
8184
transform = _conditional.IsotropicTransform()
8285
return FactImpl(
@@ -103,7 +106,7 @@ def _select_blockdiag(*, tcoeffs_like) -> FactImpl:
103106
stats = _stats.BlockDiagStats(ode_shape=ode_shape, unravel=unravel)
104107
linearise = _linearise.BlockDiagLinearisation()
105108
conditional = _conditional.BlockDiagConditional(
106-
ode_shape=ode_shape, num_derivatives=num_derivatives
109+
ode_shape=ode_shape, num_derivatives=num_derivatives, unravel_tree=unravel_tree
107110
)
108111
transform = _conditional.BlockDiagTransform(ode_shape=ode_shape)
109112
return FactImpl(

0 commit comments

Comments
 (0)