@@ -147,10 +147,11 @@ def to_derivative(self, i, standard_deviation):
147
147
148
148
149
149
class DenseConditional (ConditionalBackend ):
150
- def __init__ (self , ode_shape , num_derivatives , unravel ):
150
+ def __init__ (self , ode_shape , num_derivatives , unravel , flat_shape ):
151
151
self .ode_shape = ode_shape
152
152
self .num_derivatives = num_derivatives
153
153
self .unravel = unravel
154
+ self .flat_shape = flat_shape
154
155
155
156
def apply (self , x , conditional , / ):
156
157
matrix , noise = conditional
@@ -178,8 +179,6 @@ def revert(self, rv, conditional, /):
178
179
mean , cholesky = rv .mean , rv .cholesky
179
180
180
181
# QR-decomposition
181
- # (todo: rename revert_conditional_noisefree to
182
- # revert_transformation_cov_sqrt())
183
182
r_obs , (r_cor , gain ) = cholesky_util .revert_conditional (
184
183
R_X_F = (matrix @ cholesky ).T , R_X = cholesky .T , R_YX = noise .cholesky .T
185
184
)
@@ -208,8 +207,7 @@ def ibm_transitions(self, *, output_scale):
208
207
A = np .kron (a , eye_d )
209
208
Q = np .kron (q_sqrtm , eye_d )
210
209
211
- ndim = d * (self .num_derivatives + 1 )
212
- q0 = np .zeros ((ndim ,))
210
+ q0 = np .zeros (self .flat_shape )
213
211
noise = _normal .Normal (q0 , Q )
214
212
215
213
precon_fun = preconditioner_prepare (num_derivatives = self .num_derivatives )
@@ -230,34 +228,25 @@ def preconditioner_apply(self, cond, p, p_inv, /):
230
228
return Conditional (A , noise )
231
229
232
230
def to_derivative (self , i , standard_deviation ):
233
- a0 = functools .partial (self ._select , idx_or_slice = i )
231
+ x = np .zeros (self .flat_shape )
232
+
233
+ def select (a ):
234
+ return self .unravel (a )[i ]
235
+
236
+ linop = functools .jacrev (select )(x )
234
237
235
238
(d ,) = self .ode_shape
236
239
bias = np .zeros ((d ,))
237
240
eye = np .eye (d )
238
241
noise = _normal .Normal (bias , standard_deviation * eye )
239
-
240
- x = np .zeros (((self .num_derivatives + 1 ) * d ,))
241
- linop = _jac_materialize (lambda s , _p : self ._autobatch_linop (a0 )(s ), inputs = x )
242
242
return Conditional (linop , noise )
243
243
244
- def _select (self , x , / , idx_or_slice ):
245
- return self .unravel (x )[idx_or_slice ]
246
-
247
- @staticmethod
248
- def _autobatch_linop (fun ):
249
- def fun_ (x ):
250
- if np .ndim (x ) > 1 :
251
- return functools .vmap (fun_ , in_axes = 1 , out_axes = 1 )(x )
252
- return fun (x )
253
-
254
- return fun_
255
-
256
244
257
245
class IsotropicConditional (ConditionalBackend ):
258
- def __init__ (self , * , ode_shape , num_derivatives ):
246
+ def __init__ (self , * , ode_shape , num_derivatives , unravel_tree ):
259
247
self .ode_shape = ode_shape
260
248
self .num_derivatives = num_derivatives
249
+ self .unravel_tree = unravel_tree
261
250
262
251
def apply (self , x , conditional , / ):
263
252
A , noise = conditional
@@ -332,22 +321,24 @@ def preconditioner_apply(self, cond, p, p_inv, /):
332
321
return Conditional (A_new , noise )
333
322
334
323
def to_derivative (self , i , standard_deviation ):
335
- def A (x ):
336
- return x [[i ], ...]
324
+ def select (a ):
325
+ return tree_util .ravel_pytree (self .unravel_tree (a )[i ])[0 ]
326
+
327
+ m = np .zeros ((self .num_derivatives + 1 ,))
328
+ linop = functools .jacrev (select )(m )
337
329
338
330
bias = np .zeros (self .ode_shape )
339
331
eye = np .eye (1 )
340
332
noise = _normal .Normal (bias , standard_deviation * eye )
341
333
342
- m = np .zeros ((self .num_derivatives + 1 ,))
343
- linop = _jac_materialize (lambda s , _p : A (s ), inputs = m )
344
334
return Conditional (linop , noise )
345
335
346
336
347
337
class BlockDiagConditional (ConditionalBackend ):
348
- def __init__ (self , * , ode_shape , num_derivatives ):
338
+ def __init__ (self , * , ode_shape , num_derivatives , unravel_tree ):
349
339
self .ode_shape = ode_shape
350
340
self .num_derivatives = num_derivatives
341
+ self .unravel_tree = unravel_tree
351
342
352
343
def apply (self , x , conditional , / ):
353
344
if np .ndim (x ) == 1 :
@@ -434,15 +425,11 @@ def preconditioner_apply(self, cond, p, p_inv, /):
434
425
return Conditional (A_new , noise )
435
426
436
427
def to_derivative (self , i , standard_deviation ):
437
- def A (x ):
438
- return x [[i ], ...]
439
-
440
- @functools .vmap
441
- def lo (y ):
442
- return _jac_materialize (lambda s , _p : A (s ), inputs = y )
428
+ def select (a ):
429
+ return tree_util .ravel_pytree (self .unravel_tree (a )[i ])[0 ]
443
430
444
431
x = np .zeros ((* self .ode_shape , self .num_derivatives + 1 ))
445
- linop = lo (x )
432
+ linop = functools . vmap ( functools . jacrev ( select )) (x )
446
433
447
434
bias = np .zeros ((* self .ode_shape , 1 ))
448
435
eye = np .ones ((* self .ode_shape , 1 , 1 )) * np .eye (1 )[None , ...]
@@ -494,7 +481,3 @@ def _batch_gram(k, /):
494
481
495
482
def _binom (n , k ):
496
483
return np .factorial (n ) / (np .factorial (n - k ) * np .factorial (k ))
497
-
498
-
499
- def _jac_materialize (func , / , * , inputs , params = None ):
500
- return functools .jacrev (lambda v : func (v , params ))(inputs )
0 commit comments