From c3218a137e32b72c9e3f5b78e6e7651640008786 Mon Sep 17 00:00:00 2001 From: Feda Curic Date: Fri, 12 Jan 2024 14:15:22 +0100 Subject: [PATCH] Add overwrite flag and return references --- .../experimental.py | 12 +++++++++- tests/test_experimental.py | 23 ++++++++----------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/iterative_ensemble_smoother/experimental.py b/src/iterative_ensemble_smoother/experimental.py index efdb2f68..64013af2 100644 --- a/src/iterative_ensemble_smoother/experimental.py +++ b/src/iterative_ensemble_smoother/experimental.py @@ -162,7 +162,7 @@ def _cov_to_corr_inplace( msg += f"The min and max values are: {cov_XY.min()} and {cov_XY.max()}" warnings.warn(msg) - np.clip(cov_XY, a_min=-1, a_max=1, out=cov_XY) + return np.clip(cov_XY, a_min=-1, a_max=1, out=cov_XY) def _corr_to_cov_inplace( self, @@ -179,6 +179,7 @@ def _corr_to_cov_inplace( # Multiply each element of corr_XY by the corresponding standard deviations corr_XY *= stds_X[:, np.newaxis] corr_XY *= stds_Y[np.newaxis, :] + return corr_XY def assimilate( self, @@ -186,6 +187,7 @@ def assimilate( X: npt.NDArray[np.double], Y: npt.NDArray[np.double], D: npt.NDArray[np.double], + overwrite: bool = False, alpha: float, correlation_threshold: Union[Callable[[int], float], float, None] = None, cov_YY: Optional[npt.NDArray[np.double]] = None, @@ -219,6 +221,10 @@ def assimilate( The covariance inflation factor. The sequence of alphas should obey the equation sum_i (1/alpha_i) = 1. However, this is NOT enforced in this method. The user/caller is responsible for this. + overwrite: bool + If True, X will be overwritten and mutated. + If False, the method will not mutate inputs in any way. + Settings this to True saves memory. correlation_threshold : callable or float or None Either a callable with signature f(ensemble_size) -> float, or a float in the range [0, 1]. Entries in the covariance matrix that @@ -251,6 +257,10 @@ def assimilate( "`correlation_threshold` must be a callable or a float in [0, 1]" ) + # Do not overwrite input arguments + if not overwrite: + X = np.copy(X) + # Create `correlation_threshold` if the argument is a float if is_float: corr_threshold: float = correlation_threshold # type: ignore diff --git a/tests/test_experimental.py b/tests/test_experimental.py index 33ff7d7e..7a37986e 100644 --- a/tests/test_experimental.py +++ b/tests/test_experimental.py @@ -103,7 +103,7 @@ def test_that_adaptive_localization_with_cutoff_1_equals_ensemble_prior( ) # Update the relevant parameters and write to X - smoother.assimilate( + X_i = smoother.assimilate( X=X_i, Y=Y_i, D=D_i, @@ -149,6 +149,7 @@ def test_that_adaptive_localization_with_cutoff_0_equals_standard_ESMDA_update( X=X_i, Y=Y_i, D=D_i, + overwrite=True, alpha=alpha_i, correlation_threshold=lambda ensemble_size: 0, ) @@ -201,8 +202,6 @@ def test_that_posterior_generalized_variance_increases_in_cutoff( ) X_i = np.copy(X) - X_i_low_cutoff = np.copy(X) - X_i_high_cutoff = np.copy(X) for _, alpha_i in enumerate(alpha, 1): # Run forward model Y_i = g(X_i) @@ -215,15 +214,15 @@ def test_that_posterior_generalized_variance_increases_in_cutoff( cutoff_low, cutoff_high = cutoffs assert cutoff_low <= cutoff_high - smoother.assimilate( - X=X_i_low_cutoff, + X_i_low_cutoff = smoother.assimilate( + X=X_i, Y=Y_i, D=D_i, alpha=alpha_i, correlation_threshold=lambda ensemble_size: cutoff_low, ) - smoother.assimilate( - X=X_i_high_cutoff, + X_i_high_cutoff = smoother.assimilate( + X=X_i, Y=Y_i, D=D_i, alpha=alpha_i, @@ -413,8 +412,6 @@ def test_that_cov_YY_can_be_computed_outside_of_assimilate( ) X_i = np.copy(X) - X_i1 = np.copy(X) - X_i2 = np.copy(X) for i, alpha_i in enumerate(alpha, 1): print(f"ESMDA iteration {i} with alpha_i={alpha_i}") @@ -427,16 +424,16 @@ def test_that_cov_YY_can_be_computed_outside_of_assimilate( ) # Update the parameters without using pre-computed cov_YY - smoother.assimilate( - X=X_i1, + X_i1 = smoother.assimilate( + X=X_i, Y=Y_i, D=D_i, alpha=alpha_i, ) # Update the parameters using pre-computed cov_YY - smoother.assimilate( - X=X_i2, + X_i2 = smoother.assimilate( + X=X_i, Y=Y_i, D=D_i, alpha=alpha_i,