Skip to content

Commit

Permalink
simplified perturb_observations() argument from 'size' to 'ensemble_s…
Browse files Browse the repository at this point in the history
…ize'
  • Loading branch information
Tommy Odland committed Nov 24, 2023
1 parent 0456f26 commit 943edf6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
22 changes: 12 additions & 10 deletions src/iterative_ensemble_smoother/esmda.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import numbers
from abc import ABC
from typing import Tuple, Union
from typing import Union

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -88,7 +88,7 @@ def __init__(
assert isinstance(self.C_D, np.ndarray) and self.C_D.ndim in (1, 2)

def perturb_observations(
self, *, size: Tuple[int, int], alpha: float
self, *, ensemble_size: int, alpha: float
) -> npt.NDArray[np.double]:
"""Create a matrix D with perturbed observations.
Expand All @@ -97,8 +97,10 @@ def perturb_observations(
Parameters
----------
size : Tuple[int, int]
The size, a tuple with (num_observations, ensemble_size).
ensemble_size : int
The ensemble size, i.e., the number of perturbed observations.
This represents the number of columns in the returned matrix, which
is of shape (num_observations, ensemble_size).
alpha : float
The covariance inflation factor. The sequence of alphas should
obey the equation sum_i (1/alpha_i) = 1. However, this is NOT enforced
Expand All @@ -118,11 +120,10 @@ def perturb_observations(

# Two cases, depending on whether C_D was given as 1D or 2D array
D: npt.NDArray[np.double]
# TODO: what if we have cov = [1, 2, 3] and we mask?
# we must pick out correct indices. 'size' is not enough...
D = self.observations.reshape(-1, 1) + np.sqrt(alpha) * sample_mvnormal(
C_dd_cholesky=self.C_D_L, rng=self.rng, size=size[1]
C_dd_cholesky=self.C_D_L, rng=self.rng, size=ensemble_size
)
assert D.shape == (len(self.observations), ensemble_size)

return D

Expand Down Expand Up @@ -283,8 +284,9 @@ def assimilate(
# if C_D = L L.T by the cholesky factorization, then drawing y from
# a zero cented normal means that y := L @ z, where z ~ norm(0, 1)
# Therefore, scaling C_D by alpha is equivalent to scaling L with sqrt(alpha)
size = (num_outputs, num_ensemble)
D = self.perturb_observations(size=size, alpha=self.alpha[self.iteration])
D = self.perturb_observations(
ensemble_size=num_ensemble, alpha=self.alpha[self.iteration]
)
assert D.shape == (num_outputs, num_ensemble)

# Line 2 (c) in the description of ES-MDA in the 2013 Emerick paper
Expand Down Expand Up @@ -348,7 +350,7 @@ def compute_transition_matrix(
# or
# X += X @ T

D = self.perturb_observations(size=Y.shape, alpha=alpha)
D = self.perturb_observations(ensemble_size=Y.shape[1], alpha=alpha)
inversion_func = self._inversion_methods[self.inversion]
return inversion_func(
alpha=alpha,
Expand Down
12 changes: 9 additions & 3 deletions tests/test_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def test_that_adaptive_localization_with_cutoff_1_equals_ensemble_prior(
Y_i = g(X_i)

# Create noise D - common to this ESMDA update
D_i = smoother.perturb_observations(size=Y_i.shape, alpha=alpha_i)
D_i = smoother.perturb_observations(
ensemble_size=Y_i.shape[1], alpha=alpha_i
)

# Update the relevant parameters and write to X (storage)
X_i = smoother.adaptive_assimilate(
Expand Down Expand Up @@ -141,7 +143,9 @@ def test_that_adaptive_localization_with_cutoff_0_equals_standard_ESMDA_update(
Y_i = g(X_i)

# Create noise D - common to this ESMDA update
D_i = smoother.perturb_observations(size=Y_i.shape, alpha=alpha_i)
D_i = smoother.perturb_observations(
ensemble_size=Y_i.shape[1], alpha=alpha_i
)

# Update the relevant parameters and write to X (storage)
X_i = smoother.adaptive_assimilate(
Expand Down Expand Up @@ -203,7 +207,9 @@ def test_that_posterior_generalized_variance_increases_in_cutoff(
Y_i = g(X_i)

# Create noise D - common to this ESMDA update
D_i = smoother.perturb_observations(size=Y_i.shape, alpha=alpha_i)
D_i = smoother.perturb_observations(
ensemble_size=Y_i.shape[1], alpha=alpha_i
)

cutoff_low, cutoff_high = cutoffs
assert cutoff_low <= cutoff_high
Expand Down

0 comments on commit 943edf6

Please sign in to comment.