fix bug when using full covariance
Tommy Odland committed Dec 6, 2023
1 parent c8a9303 commit 59992d3
Showing 2 changed files with 15 additions and 37 deletions.
src/iterative_ensemble_smoother/experimental.py: 11 additions & 36 deletions
@@ -17,41 +17,6 @@
)


-def groupby_indices_feda(X):
-    """
-    Examples
-    --------
-    With a boolean input matrix:
-    >>> X = np.array([[0, 0, 0],
-    ...               [0, 0, 0],
-    ...               [1, 0, 0],
-    ...               [1, 0, 0],
-    ...               [1, 1, 1],
-    ...               [1, 1, 1]], dtype=bool)
-    >>> list(groupby_indices_feda(X))
-    [(array([ True, False, False]), array([2, 3])), \
-(array([ True,  True,  True]), array([4, 5]))]
-    """
-
-    # Unique rows
-    param_correlation_sets = np.unique(X, axis=0)
-
-    # Drop the correlation set that does not correlate to any responses.
-    row_with_all_false = np.all(~param_correlation_sets, axis=1)
-    param_correlation_sets = param_correlation_sets[~row_with_all_false]
-
-    for param_correlation_set in param_correlation_sets:
-        # Find the rows matching the parameter group
-        matching_rows = np.all(X == param_correlation_set, axis=1)
-
-        # Get the indices of the matching rows
-        row_indices = np.where(matching_rows)[0]
-
-        yield (param_correlation_set, row_indices)
-
-
def groupby_indices(X):
"""Yield pairs of (unique_row, indices_of_row).
@@ -91,8 +56,13 @@ def groupby_indices(X):

# Code was modified from this answer:
# https://stackoverflow.com/questions/30003068/how-to-get-a-list-of-all-indices-of-repeated-elements-in-a-numpy-array
+
+# ~15% of the total time is spent in np.lexsort on a large problem
idx_sort = np.lexsort(X.T[::-1, :], axis=0)
sorted_X = X[idx_sort, :]
+
+# ~85% of the total time is spent in np.unique on a large problem,
+# but np.unique is roughly linear in complexity, so little can be done.
vals, idx_start = np.unique(sorted_X, return_index=True, axis=0)
res = np.split(idx_sort, idx_start[1:])
yield from zip(vals, res)
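
The groupby_indices generator kept above groups identical rows of X by sorting the rows lexicographically with np.lexsort, locating the first occurrence of each distinct row with np.unique, and splitting the sort permutation at those boundaries. A minimal self-contained sketch of the same lexsort/unique/split idea; groupby_rows is an illustrative name, not the package function:

import numpy as np

def groupby_rows(X):
    # Sort rows lexicographically: lexsort treats the *last* key as the
    # primary key, so the transposed columns are reversed first.
    idx_sort = np.lexsort(X.T[::-1, :], axis=0)
    sorted_X = X[idx_sort, :]
    # Unique rows of the sorted matrix, plus the index where each group starts.
    vals, idx_start = np.unique(sorted_X, return_index=True, axis=0)
    # Split the sort permutation at every group boundary.
    yield from zip(vals, np.split(idx_sort, idx_start[1:]))

X = np.array([[1, 0], [0, 1], [1, 0]])
for row, indices in groupby_rows(X):
    print(row, indices)
# [0 1] [1]
# [1 0] [0 2]
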
@@ -326,7 +296,12 @@ def correlation_threshold(ensemble_size):
cov_YY_mask = np.ix_(unique_row, unique_row)
cov_YY_subset = cov_YY[cov_YY_mask]

-C_D_subset = self.C_D[unique_row]
+# Slice the covariance matrix
+if self.C_D.ndim == 1:
+    C_D_subset = self.C_D[unique_row]
+else:
+    C_D_subset = self.C_D[np.ix_(unique_row, unique_row)]
+
D_subset = D[unique_row, :]

# Compute transition matrix T
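
The if/else block above is the fix itself: self.C_D may hold either a 1D vector of variances or a full 2D covariance matrix, and selecting a subset of observations from a 2D matrix has to restrict rows and columns together, which is what np.ix_ does. A small numpy-only sketch of the indexing difference; subset_covariance is an illustrative helper, not part of the package:

import numpy as np

def subset_covariance(C_D, mask):
    """Restrict the observation error covariance to the observations
    selected by the boolean mask, for either storage format."""
    if C_D.ndim == 1:
        # Diagonal stored as a 1D vector of variances: plain indexing works.
        return C_D[mask]
    # Full 2D matrix: np.ix_ selects the matching rows *and* columns,
    # keeping the result square.
    return C_D[np.ix_(mask, mask)]

variances = np.array([1.0, 2.0, 3.0])
full_matrix = np.diag(variances)
mask = np.array([True, False, True])

print(subset_covariance(variances, mask))    # [1. 3.]
print(subset_covariance(full_matrix, mask))  # [[1. 0.]
                                             #  [0. 3.]]
# The pre-fix code indexed a full matrix with plain rows only, which
# returns a rectangular block and breaks the shapes downstream:
print(full_matrix[mask].shape)               # (2, 3)
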
tests/test_experimental.py: 4 additions & 1 deletion
@@ -182,15 +182,18 @@ def test_that_adaptive_localization_with_cutoff_0_equals_standard_ESMDA_update(
@pytest.mark.parametrize(
    "cutoffs", [(0, 1e-3), (0.1, 0.2), (0.5, 0.5 + 1e-12), (0.9, 1), (1 - 1e-3, 1)]
)
+@pytest.mark.parametrize("full_covariance", [True, False])
def test_that_posterior_generalized_variance_increases_in_cutoff(
-    self, linear_problem, cutoffs
+    self, linear_problem, cutoffs, full_covariance
):
    """This property only holds in the limit as the number of
    ensemble members goes to infinity. As the number of ensemble
    members decrease, this test starts to fail more often."""

    # Create a problem with g(x) = A @ x
    X, g, observations, covariance, rng = linear_problem
+    if full_covariance:
+        covariance = np.diag(covariance)

    # =============================================================================
    # SETUP ESMDA FOR LOCALIZATION AND SOLVE PROBLEM
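
The new full_covariance parameter reruns the property test with the observation error covariance written as a full matrix, since np.diag of the 1D variance vector describes the same diagonal error model and the posterior should not depend on which form is passed in. A numpy-only sketch of that equivalence, using a simplified Kalman-style solve rather than the library's code path:

import numpy as np
import pytest

@pytest.mark.parametrize("full_covariance", [True, False])
def test_update_is_independent_of_covariance_storage(full_covariance):
    rng = np.random.default_rng(0)
    num_obs, ensemble_size = 5, 200

    Y = rng.normal(size=(num_obs, ensemble_size))  # simulated responses
    D = rng.normal(size=(num_obs, ensemble_size))  # perturbed observations
    variances = rng.uniform(0.5, 2.0, size=num_obs)

    covariance = np.diag(variances) if full_covariance else variances
    # Expand a 1D vector of variances into the diagonal matrix it represents.
    C_D = covariance if covariance.ndim == 2 else np.diag(covariance)

    # The solve must not depend on whether the covariance arrived as a
    # vector or as a full matrix.
    T = np.linalg.solve(np.cov(Y) + C_D, D - Y)
    T_reference = np.linalg.solve(np.cov(Y) + np.diag(variances), D - Y)
    np.testing.assert_allclose(T, T_reference)
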
