diff --git a/src/iterative_ensemble_smoother/experimental.py b/src/iterative_ensemble_smoother/experimental.py index e5764a12..1d86c492 100644 --- a/src/iterative_ensemble_smoother/experimental.py +++ b/src/iterative_ensemble_smoother/experimental.py @@ -3,11 +3,11 @@ features of iterative_ensemble_smoother """ import numbers -from typing import List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union import numpy as np import numpy.typing as npt -import scipy as sp +import scipy as sp # type: ignore from iterative_ensemble_smoother import ESMDA from iterative_ensemble_smoother.esmda import BaseESMDA @@ -16,7 +16,9 @@ ) -def groupby_indices(X): +def groupby_indices( + X: npt.NDArray[Any], +) -> Generator[Dict[npt.NDArray[np.double], npt.NDArray[np.int_]], None, None]: """Yield pairs of (unique_row, indices_of_row). Examples @@ -74,7 +76,7 @@ def correlation_threshold(ensemble_size: int) -> float: >>> AdaptiveESMDA.correlation_threshold(36) 0.5 """ - return min(1, max(0, 3 / np.sqrt(ensemble_size))) + return float(min(1, max(0, 3 / np.sqrt(ensemble_size)))) @staticmethod def compute_cross_covariance_multiplier( @@ -137,9 +139,14 @@ def compute_cross_covariance_multiplier( # C_D is an array, so add it to the diagonal without forming diag(C_D) np.fill_diagonal(C_DD, C_DD.diagonal() + alpha * C_D) - return sp.linalg.solve(C_DD, D - Y, **solver_kwargs) + return sp.linalg.solve(C_DD, D - Y, **solver_kwargs) # type: ignore - def _correlation_matrix(self, cov_XY, X, Y): + def _correlation_matrix( + self, + cov_XY: npt.NDArray[np.double], + X: npt.NDArray[np.double], + Y: npt.NDArray[np.double], + ) -> npt.NDArray[np.double]: """Compute a correlation matrix given a covariance matrix.""" assert cov_XY.shape == (X.shape[0], Y.shape[0]) @@ -147,14 +154,25 @@ def _correlation_matrix(self, cov_XY, X, Y): stds_X = np.std(X, axis=1, ddof=1) # Compute the correlation matrix from the covariance matrix - corr_XY = (cov_XY / stds_X[:, np.newaxis]) / stds_Y[np.newaxis, :] + corr_XY: npt.NDArray[np.double] = (cov_XY / stds_X[:, np.newaxis]) / stds_Y[ + np.newaxis, : + ] # Perform checks assert corr_XY.max() <= 1 assert corr_XY.min() >= -1 return corr_XY - def assimilate(self, X, Y, D, alpha, correlation_threshold=None, verbose=False): + def assimilate( + self, + *, + X: npt.NDArray[np.double], + Y: npt.NDArray[np.double], + D: npt.NDArray[np.double], + alpha: float, + correlation_threshold: Union[Callable[[int], float], float, None] = None, + verbose: bool = False, + ) -> npt.NDArray[np.double]: """Assimilate data and return an updated ensemble X_posterior. X_posterior = smoother.assimilate(X, Y, D, alpha) @@ -213,9 +231,9 @@ def assimilate(self, X, Y, D, alpha, correlation_threshold=None, verbose=False): # Create `correlation_threshold` if the argument is a float if is_float: - corr_threshold = correlation_threshold + corr_threshold: float = correlation_threshold # type: ignore - def correlation_threshold(ensemble_size): + def correlation_threshold(ensemble_size: int) -> float: return corr_threshold # Default correlation threshold function @@ -238,14 +256,14 @@ def correlation_threshold(ensemble_size): # Determine which elements in the cross covariance matrix that will # be set to zero - threshold = correlation_threshold(ensemble_size=X.shape[1]) + threshold = correlation_threshold(X.shape[1]) significant_corr_XY = np.abs(corr_XY) > threshold # Pre-compute the covariance cov(Y, Y) here, and index on it later cov_YY = empirical_cross_covariance(Y, Y) # TODO: memory could be saved by overwriting the input X - X_out = np.copy(X) + X_out: npt.NDArray[np.double] = np.copy(X) for (unique_row, indices_of_row) in groupby_indices(significant_corr_XY): if verbose: