Skip to content

Commit

Permalink
Types for AdaptiveESMDA (#191)
Browse files Browse the repository at this point in the history
* types for AdaptiveESMDA

---------

Co-authored-by: Tommy Odland <tommy.odland>
  • Loading branch information
tommyod authored Dec 6, 2023
1 parent b9da2de commit 4e2fb9b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/iterative_ensemble_smoother/esmda_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def inversion_subspace(
C_D: npt.NDArray[np.double],
D: npt.NDArray[np.double],
Y: npt.NDArray[np.double],
X: npt.NDArray[np.double],
X: Optional[npt.NDArray[np.double]],
truncation: float = 1.0,
return_T: bool = False,
) -> npt.NDArray[np.double]:
Expand Down
48 changes: 32 additions & 16 deletions src/iterative_ensemble_smoother/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
"""
import numbers
import warnings
from typing import List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt
import scipy as sp
import scipy as sp # type: ignore

from iterative_ensemble_smoother import ESMDA
from iterative_ensemble_smoother.esmda import BaseESMDA
Expand All @@ -17,7 +17,9 @@
)


def groupby_indices(X):
def groupby_indices(
X: npt.NDArray[Any],
) -> Generator[Dict[npt.NDArray[np.double], npt.NDArray[np.int_]], None, None]:
"""Yield pairs of (unique_row, indices_of_row).
Examples
Expand Down Expand Up @@ -65,17 +67,15 @@ def correlation_threshold(ensemble_size: int) -> float:
Examples
--------
>>> AdaptiveESMDA.correlation_threshold(0)
1
>>> AdaptiveESMDA.correlation_threshold(4)
1
1.0
>>> AdaptiveESMDA.correlation_threshold(9)
1
1.0
>>> AdaptiveESMDA.correlation_threshold(16)
0.75
>>> AdaptiveESMDA.correlation_threshold(36)
0.5
"""
return min(1, max(0, 3 / np.sqrt(ensemble_size)))
return float(min(1, max(0, 3 / np.sqrt(ensemble_size))))

@staticmethod
def compute_cross_covariance_multiplier(
Expand Down Expand Up @@ -138,17 +138,24 @@ def compute_cross_covariance_multiplier(
# C_D is an array, so add it to the diagonal without forming diag(C_D)
np.fill_diagonal(C_DD, C_DD.diagonal() + alpha * C_D)

return sp.linalg.solve(C_DD, D - Y, **solver_kwargs)
return sp.linalg.solve(C_DD, D - Y, **solver_kwargs) # type: ignore

def _correlation_matrix(self, cov_XY, X, Y):
def _correlation_matrix(
self,
cov_XY: npt.NDArray[np.double],
X: npt.NDArray[np.double],
Y: npt.NDArray[np.double],
) -> npt.NDArray[np.double]:
"""Compute a correlation matrix given a covariance matrix."""
assert cov_XY.shape == (X.shape[0], Y.shape[0])

stds_Y = np.std(Y, axis=1, ddof=1)
stds_X = np.std(X, axis=1, ddof=1)

# Compute the correlation matrix from the covariance matrix
corr_XY = (cov_XY / stds_X[:, np.newaxis]) / stds_Y[np.newaxis, :]
corr_XY: npt.NDArray[np.double] = (cov_XY / stds_X[:, np.newaxis]) / stds_Y[
np.newaxis, :
]

# Perform checks. There appears to be occasional numerical issues in
# the equation. With 2 ensemble members, we get e.g. a max value of
Expand All @@ -162,7 +169,16 @@ def _correlation_matrix(self, cov_XY, X, Y):
corr_XY = np.clip(corr_XY, a_min=-1, a_max=1)
return corr_XY

def assimilate(self, X, Y, D, alpha, correlation_threshold=None, verbose=False):
def assimilate(
self,
*,
X: npt.NDArray[np.double],
Y: npt.NDArray[np.double],
D: npt.NDArray[np.double],
alpha: float,
correlation_threshold: Union[Callable[[int], float], float, None] = None,
verbose: bool = False,
) -> npt.NDArray[np.double]:
"""Assimilate data and return an updated ensemble X_posterior.
X_posterior = smoother.assimilate(X, Y, D, alpha)
Expand Down Expand Up @@ -221,9 +237,9 @@ def assimilate(self, X, Y, D, alpha, correlation_threshold=None, verbose=False):

# Create `correlation_threshold` if the argument is a float
if is_float:
corr_threshold = correlation_threshold
corr_threshold: float = correlation_threshold # type: ignore

def correlation_threshold(ensemble_size):
def correlation_threshold(ensemble_size: int) -> float:
return corr_threshold

# Default correlation threshold function
Expand All @@ -246,14 +262,14 @@ def correlation_threshold(ensemble_size):

# Determine which elements in the cross covariance matrix that will
# be set to zero
threshold = correlation_threshold(ensemble_size=X.shape[1])
threshold = correlation_threshold(X.shape[1])
significant_corr_XY = np.abs(corr_XY) > threshold

# Pre-compute the covariance cov(Y, Y) here, and index on it later
cov_YY = empirical_cross_covariance(Y, Y)

# TODO: memory could be saved by overwriting the input X
X_out = np.copy(X)
X_out: npt.NDArray[np.double] = np.copy(X)
for (unique_row, indices_of_row) in groupby_indices(significant_corr_XY):

if verbose:
Expand Down

0 comments on commit 4e2fb9b

Please sign in to comment.