From 943edf69f45b1d0d185d47c15448c2ad11bf461d Mon Sep 17 00:00:00 2001
From: Tommy Odland
Date: Fri, 24 Nov 2023 08:40:06 +0100
Subject: [PATCH] simplified perturb_observations() argument from 'size' to 'ensemble_size'

---
 src/iterative_ensemble_smoother/esmda.py | 22 ++++++++++++----------
 tests/test_experimental.py               | 12 +++++++++---
 2 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/src/iterative_ensemble_smoother/esmda.py b/src/iterative_ensemble_smoother/esmda.py
index a72dafc9..cff78f5d 100644
--- a/src/iterative_ensemble_smoother/esmda.py
+++ b/src/iterative_ensemble_smoother/esmda.py
@@ -24,7 +24,7 @@
 
 import numbers
 from abc import ABC
-from typing import Tuple, Union
+from typing import Union
 
 import numpy as np
 import numpy.typing as npt
@@ -88,7 +88,7 @@ def __init__(
         assert isinstance(self.C_D, np.ndarray) and self.C_D.ndim in (1, 2)
 
     def perturb_observations(
-        self, *, size: Tuple[int, int], alpha: float
+        self, *, ensemble_size: int, alpha: float
     ) -> npt.NDArray[np.double]:
         """Create a matrix D with perturbed observations.
 
@@ -97,8 +97,10 @@ def perturb_observations(
 
         Parameters
         ----------
-        size : Tuple[int, int]
-            The size, a tuple with (num_observations, ensemble_size).
+        ensemble_size : int
+            The ensemble size, i.e., the number of perturbed observations.
+            This represents the number of columns in the returned matrix, which
+            is of shape (num_observations, ensemble_size).
         alpha : float
             The covariance inflation factor. The sequence of alphas should
             obey the equation sum_i (1/alpha_i) = 1. However, this is NOT enforced
@@ -118,11 +120,10 @@ def perturb_observations(
 
         # Two cases, depending on whether C_D was given as 1D or 2D array
         D: npt.NDArray[np.double]
-        # TODO: what if we have cov = [1, 2, 3] and we mask?
-        # we must pick out correct indices. 'size' is not enough...
         D = self.observations.reshape(-1, 1) + np.sqrt(alpha) * sample_mvnormal(
-            C_dd_cholesky=self.C_D_L, rng=self.rng, size=size[1]
+            C_dd_cholesky=self.C_D_L, rng=self.rng, size=ensemble_size
         )
+        assert D.shape == (len(self.observations), ensemble_size)
         return D
 
 
@@ -283,8 +284,9 @@ def assimilate(
         # if C_D = L L.T by the cholesky factorization, then drawing y from
         # a zero cented normal means that y := L @ z, where z ~ norm(0, 1)
         # Therefore, scaling C_D by alpha is equivalent to scaling L with sqrt(alpha)
-        size = (num_outputs, num_ensemble)
-        D = self.perturb_observations(size=size, alpha=self.alpha[self.iteration])
+        D = self.perturb_observations(
+            ensemble_size=num_ensemble, alpha=self.alpha[self.iteration]
+        )
         assert D.shape == (num_outputs, num_ensemble)
 
         # Line 2 (c) in the description of ES-MDA in the 2013 Emerick paper
@@ -348,7 +350,7 @@ def compute_transition_matrix(
         # or
         # X += X @ T
 
-        D = self.perturb_observations(size=Y.shape, alpha=alpha)
+        D = self.perturb_observations(ensemble_size=Y.shape[1], alpha=alpha)
         inversion_func = self._inversion_methods[self.inversion]
         return inversion_func(
             alpha=alpha,
diff --git a/tests/test_experimental.py b/tests/test_experimental.py
index e8efd3c7..7a28878d 100644
--- a/tests/test_experimental.py
+++ b/tests/test_experimental.py
@@ -99,7 +99,9 @@ def test_that_adaptive_localization_with_cutoff_1_equals_ensemble_prior(
         Y_i = g(X_i)
 
         # Create noise D - common to this ESMDA update
-        D_i = smoother.perturb_observations(size=Y_i.shape, alpha=alpha_i)
+        D_i = smoother.perturb_observations(
+            ensemble_size=Y_i.shape[1], alpha=alpha_i
+        )
 
         # Update the relevant parameters and write to X (storage)
         X_i = smoother.adaptive_assimilate(
@@ -141,7 +143,9 @@ def test_that_adaptive_localization_with_cutoff_0_equals_standard_ESMDA_update(
         Y_i = g(X_i)
 
         # Create noise D - common to this ESMDA update
-        D_i = smoother.perturb_observations(size=Y_i.shape, alpha=alpha_i)
+        D_i = smoother.perturb_observations(
+            ensemble_size=Y_i.shape[1], alpha=alpha_i
+        )
 
         # Update the relevant parameters and write to X (storage)
         X_i = smoother.adaptive_assimilate(
@@ -203,7 +207,9 @@ def test_that_posterior_generalized_variance_increases_in_cutoff(
         Y_i = g(X_i)
 
         # Create noise D - common to this ESMDA update
-        D_i = smoother.perturb_observations(size=Y_i.shape, alpha=alpha_i)
+        D_i = smoother.perturb_observations(
+            ensemble_size=Y_i.shape[1], alpha=alpha_i
+        )
 
         cutoff_low, cutoff_high = cutoffs
         assert cutoff_low <= cutoff_high
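
Usage note (illustrative sketch, not part of the patch): the snippet below shows how a caller uses the new keyword-only ensemble_size argument instead of a (num_observations, ensemble_size) tuple. The ESMDA constructor arguments shown here (covariance, observations, alpha, seed) are assumptions about the package's public API and are not touched by this patch; adjust them to the installed version.

# Minimal sketch of the new perturb_observations() keyword, under the
# assumptions stated above.
import numpy as np

from iterative_ensemble_smoother import ESMDA

num_observations, ensemble_size = 10, 25
rng = np.random.default_rng(42)

observations = rng.normal(size=num_observations)
covariance = np.ones(num_observations)  # 1D array => diagonal observation covariance

smoother = ESMDA(
    covariance=covariance,
    observations=observations,
    alpha=3,  # number of ES-MDA assimilations
    seed=42,
)

# Only the number of ensemble members (columns) is passed; the row count is
# always len(observations), as the assert added in this patch documents.
D = smoother.perturb_observations(ensemble_size=ensemble_size, alpha=1.0)
assert D.shape == (num_observations, ensemble_size)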