Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add correlation callback to adaptive loc #216

Merged
merged 2 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ text = "GPL-3.0"
[project.optional-dependencies]
doc = [
"sphinx",
"docutils<0.21",
"pydata_sphinx_theme",
"jupyter_sphinx",
"sphinxcontrib.bibtex",
Expand Down
9 changes: 9 additions & 0 deletions src/iterative_ensemble_smoother/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def assimilate(
correlation_threshold: Union[Callable[[int], float], float, None] = None,
cov_YY: Optional[npt.NDArray[np.double]] = None,
progress_callback: Optional[Callable[[Sequence[T]], Sequence[T]]] = None,
correlation_callback: Optional[Callable[[npt.NDArray[np.double]], None]] = None,
) -> npt.NDArray[np.double]:
"""Assimilate data and return an updated ensemble X_posterior.

Expand Down Expand Up @@ -199,6 +200,11 @@ def assimilate(
which can provide visual feedback on the progress of the
assimilation process.
If None, no progress reporting is performed.
correlation_callback : Optional[Callable]
A callback function that is called with the correlation matrix (2D array)
as its argument after the correlation matrix computation is complete.
The callback should handle or process the correlation matrix, such as
saving or logging it. The callback should not return any value.
Returns
-------
X_posterior : np.ndarray
Expand Down Expand Up @@ -305,6 +311,9 @@ def progress_callback(x):
)
X[[param_num], :] += cov_XY_subset @ T

if correlation_callback is not None:
corr_XY = self._cov_to_corr_inplace(cov_XY, stds_X, stds_Y)
correlation_callback(corr_XY[significant_rows])
return X


Expand Down
22 changes: 21 additions & 1 deletion tests/test_experimental.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import time
from copy import deepcopy

Expand Down Expand Up @@ -60,6 +61,15 @@ def test_that_adaptive_localization_with_cutoff_1_equals_ensemble_prior(
covariance=covariance, observations=observations, seed=1
)

def correlation_callback(corr_matrix):
# A correlation threshold of 1 means that no
# correlations are deemed significant.
# Therefore, the cross-correlation matrix must
# not include any parameter-response pairs.
print(corr_matrix)
assert corr_matrix.shape[0] == 0
assert corr_matrix.shape[1] == len(observations)

X_i = np.copy(X)
alpha = normalize_alpha(np.ones(5))
for alpha_i in alpha:
Expand All @@ -78,6 +88,7 @@ def test_that_adaptive_localization_with_cutoff_1_equals_ensemble_prior(
D=D_i,
alpha=alpha_i,
correlation_threshold=1,
correlation_callback=correlation_callback,
)

assert np.allclose(X, X_i)
Expand All @@ -92,7 +103,7 @@ def test_that_adaptive_localization_with_cutoff_0_equals_standard_ESMDA_update(
self, linear_problem
):
# Create a problem with g(x) = A @ x
X, g, observations, covariance, rng = linear_problem
X, g, observations, covariance, _ = linear_problem

# =============================================================================
# SETUP ESMDA FOR LOCALIZATION AND SOLVE PROBLEM
Expand All @@ -103,6 +114,14 @@ def test_that_adaptive_localization_with_cutoff_0_equals_standard_ESMDA_update(
covariance=covariance, observations=observations, seed=1
)

def correlation_callback(corr_matrix):
# A correlation threshold of 0 means that all
# correlations are deemed significant.
# Therefore, the cross-correlation matrix must
# include all parameter-resposne pairs.
assert corr_matrix.shape[0] == X.shape[0]
assert corr_matrix.shape[1] == len(observations)

X_i = np.copy(X)
for _, alpha_i in enumerate(alpha, 1):
# Run forward model
Expand All @@ -121,6 +140,7 @@ def test_that_adaptive_localization_with_cutoff_0_equals_standard_ESMDA_update(
overwrite=True,
alpha=alpha_i,
correlation_threshold=0,
correlation_callback=functools.partial(correlation_callback),
)

# =============================================================================
Expand Down
Loading