diff --git a/src/iterative_ensemble_smoother/experimental.py b/src/iterative_ensemble_smoother/experimental.py index 85eab9d6..e3b772f4 100644 --- a/src/iterative_ensemble_smoother/experimental.py +++ b/src/iterative_ensemble_smoother/experimental.py @@ -4,7 +4,7 @@ """ import numbers import warnings -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -17,43 +17,6 @@ ) -def groupby_indices( - X: npt.NDArray[Any], -) -> Generator[Dict[npt.NDArray[np.double], npt.NDArray[np.int_]], None, None]: - """Yield pairs of (unique_row, indices_of_row). - - Examples - -------- - >>> X = np.array([[1, 0], - ... [1, 0], - ... [1, 1], - ... [1, 1], - ... [1, 0]]) - >>> list(groupby_indices(X)) - [(array([1, 0]), array([0, 1, 4])), (array([1, 1]), array([2, 3]))] - - Another example: - >>> X = np.array([[1, 2, 3], - ... [0, 0, 0], - ... [1, 2, 3], - ... [1, 1, 1], - ... [1, 1, 1], - ... [1, 2, 3]]) - >>> list(groupby_indices(X)) - [(array([0, 0, 0]), array([1])), (array([1, 1, 1]), \ -array([3, 4])), (array([1, 2, 3]), array([0, 2, 5]))] - """ - assert X.ndim == 2 - - # Code was modified from this answer: - # https://stackoverflow.com/questions/30003068/how-to-get-a-list-of-all-indices-of-repeated-elements-in-a-numpy-array - idx_sort = np.lexsort(X.T[::-1, :], axis=0) - sorted_X = X[idx_sort, :] - vals, idx_start = np.unique(sorted_X, return_index=True, axis=0) - res = np.split(idx_sort, idx_start[1:]) - yield from zip(vals, res) - - class AdaptiveESMDA(BaseESMDA): @staticmethod def correlation_threshold(ensemble_size: int) -> float: @@ -294,36 +257,30 @@ def correlation_threshold(ensemble_size: int) -> float: assert cov_YY.ndim == 2, "'cov_YY' must be a 2D array" assert cov_YY.shape == (Y.shape[0], Y.shape[0]) - for unique_row, indices_of_row in groupby_indices(significant_corr_XY): - if verbose: - print( - f" Assimilating {len(indices_of_row)} parameters" - + " with identical correlation thresholds to responses." - ) - print( - " The parameters are significant wrt " - + f"{np.sum(unique_row)} / {len(unique_row)} responses." - ) + # Identify rows with at least one significant correlation. + significant_rows = np.any(significant_corr_XY, axis=1) - # These parameters are not significantly correlated to any responses - if np.all(~unique_row): - continue + # Loop only over rows with significant correlations + for param_num in np.where(significant_rows)[0]: + correlated_responses = significant_corr_XY[param_num] - Y_subset = Y[unique_row, :] + Y_subset = Y[correlated_responses, :] # Compute the masked arrays for these variables - cov_XY_mask = np.ix_(indices_of_row, unique_row) + cov_XY_mask = np.ix_([param_num], correlated_responses) cov_XY_subset = cov_XY[cov_XY_mask] - cov_YY_mask = np.ix_(unique_row, unique_row) + cov_YY_mask = np.ix_(correlated_responses, correlated_responses) cov_YY_subset = cov_YY[cov_YY_mask] # Slice the covariance matrix C_D_subset = ( - self.C_D[unique_row] if self.C_D.ndim == 1 else self.C_D[cov_YY_mask] + self.C_D[correlated_responses] + if self.C_D.ndim == 1 + else self.C_D[cov_YY_mask] ) - D_subset = D[unique_row, :] + D_subset = D[correlated_responses, :] # Compute transition matrix T T = self.compute_cross_covariance_multiplier( @@ -333,7 +290,7 @@ def correlation_threshold(ensemble_size: int) -> float: Y=Y_subset, cov_YY=cov_YY_subset, # Passing cov(Y, Y) avoids re-computation ) - X[indices_of_row, :] += cov_XY_subset @ T + X[[param_num], :] += cov_XY_subset @ T return X diff --git a/tests/test_experimental.py b/tests/test_experimental.py index 39e36ec8..73565ff7 100644 --- a/tests/test_experimental.py +++ b/tests/test_experimental.py @@ -1,4 +1,3 @@ -import functools import time from copy import deepcopy @@ -13,39 +12,9 @@ from iterative_ensemble_smoother.experimental import ( AdaptiveESMDA, ensemble_smoother_update_step_row_scaling, - groupby_indices, ) -@pytest.mark.parametrize("seed", range(25)) -def test_groupby_indices(seed): - rng = np.random.default_rng(seed) - rows = rng.integers(10, 100) - columns = rng.integers(2, 9) - - # Create data matrix - X = rng.integers(0, 10, size=(rows, columns)) - - groups = list(groupby_indices(X)) - indices = [set(idx) for (_, idx) in groups] - - # Verify that every row is included - union_idx = functools.reduce(set.union, indices) - assert union_idx == set(range(rows)) - - # Verify that no duplicate rows occur - intersection_idx = functools.reduce(set.intersection, indices) - assert intersection_idx == set() - - # Verify each entry in the groups - for unique_row, indices_of_row in groups: - # Repeat this unique row the number of times it occurs in X - repeated = np.repeat( - unique_row[np.newaxis, :], repeats=len(indices_of_row), axis=0 - ) - assert np.allclose(X[indices_of_row, :], repeated) - - @pytest.fixture() def linear_problem(request): # Seed the problem using indirect parametrization: