Skip to content

Commit

Permalink
Add overwrite flag and return references
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Jan 12, 2024
1 parent e55a472 commit c3218a1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
12 changes: 11 additions & 1 deletion src/iterative_ensemble_smoother/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _cov_to_corr_inplace(
msg += f"The min and max values are: {cov_XY.min()} and {cov_XY.max()}"
warnings.warn(msg)

np.clip(cov_XY, a_min=-1, a_max=1, out=cov_XY)
return np.clip(cov_XY, a_min=-1, a_max=1, out=cov_XY)

def _corr_to_cov_inplace(
self,
Expand All @@ -179,13 +179,15 @@ def _corr_to_cov_inplace(
# Multiply each element of corr_XY by the corresponding standard deviations
corr_XY *= stds_X[:, np.newaxis]
corr_XY *= stds_Y[np.newaxis, :]
return corr_XY

def assimilate(
self,
*,
X: npt.NDArray[np.double],
Y: npt.NDArray[np.double],
D: npt.NDArray[np.double],
overwrite: bool = False,
alpha: float,
correlation_threshold: Union[Callable[[int], float], float, None] = None,
cov_YY: Optional[npt.NDArray[np.double]] = None,
Expand Down Expand Up @@ -219,6 +221,10 @@ def assimilate(
The covariance inflation factor. The sequence of alphas should
obey the equation sum_i (1/alpha_i) = 1. However, this is NOT
enforced in this method. The user/caller is responsible for this.
overwrite: bool
If True, X will be overwritten and mutated.
If False, the method will not mutate inputs in any way.
        Setting this to True saves memory.
correlation_threshold : callable or float or None
Either a callable with signature f(ensemble_size) -> float, or a
float in the range [0, 1]. Entries in the covariance matrix that
Expand Down Expand Up @@ -251,6 +257,10 @@ def assimilate(
"`correlation_threshold` must be a callable or a float in [0, 1]"
)

# Do not overwrite input arguments
if not overwrite:
X = np.copy(X)

# Create `correlation_threshold` if the argument is a float
if is_float:
corr_threshold: float = correlation_threshold # type: ignore
Expand Down
23 changes: 10 additions & 13 deletions tests/test_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_that_adaptive_localization_with_cutoff_1_equals_ensemble_prior(
)

# Update the relevant parameters and write to X
smoother.assimilate(
X_i = smoother.assimilate(
X=X_i,
Y=Y_i,
D=D_i,
Expand Down Expand Up @@ -149,6 +149,7 @@ def test_that_adaptive_localization_with_cutoff_0_equals_standard_ESMDA_update(
X=X_i,
Y=Y_i,
D=D_i,
overwrite=True,
alpha=alpha_i,
correlation_threshold=lambda ensemble_size: 0,
)
Expand Down Expand Up @@ -201,8 +202,6 @@ def test_that_posterior_generalized_variance_increases_in_cutoff(
)

X_i = np.copy(X)
X_i_low_cutoff = np.copy(X)
X_i_high_cutoff = np.copy(X)
for _, alpha_i in enumerate(alpha, 1):
# Run forward model
Y_i = g(X_i)
Expand All @@ -215,15 +214,15 @@ def test_that_posterior_generalized_variance_increases_in_cutoff(
cutoff_low, cutoff_high = cutoffs
assert cutoff_low <= cutoff_high

smoother.assimilate(
X=X_i_low_cutoff,
X_i_low_cutoff = smoother.assimilate(
X=X_i,
Y=Y_i,
D=D_i,
alpha=alpha_i,
correlation_threshold=lambda ensemble_size: cutoff_low,
)
smoother.assimilate(
X=X_i_high_cutoff,
X_i_high_cutoff = smoother.assimilate(
X=X_i,
Y=Y_i,
D=D_i,
alpha=alpha_i,
Expand Down Expand Up @@ -413,8 +412,6 @@ def test_that_cov_YY_can_be_computed_outside_of_assimilate(
)

X_i = np.copy(X)
X_i1 = np.copy(X)
X_i2 = np.copy(X)
for i, alpha_i in enumerate(alpha, 1):
print(f"ESMDA iteration {i} with alpha_i={alpha_i}")

Expand All @@ -427,16 +424,16 @@ def test_that_cov_YY_can_be_computed_outside_of_assimilate(
)

# Update the parameters without using pre-computed cov_YY
smoother.assimilate(
X=X_i1,
X_i1 = smoother.assimilate(
X=X_i,
Y=Y_i,
D=D_i,
alpha=alpha_i,
)

# Update the parameters using pre-computed cov_YY
smoother.assimilate(
X=X_i2,
X_i2 = smoother.assimilate(
X=X_i,
Y=Y_i,
D=D_i,
alpha=alpha_i,
Expand Down

0 comments on commit c3218a1

Please sign in to comment.