Skip to content

Commit

Permalink
types for AdaptiveESMDA
Browse files Browse the repository at this point in the history
  • Loading branch information
Tommy Odland committed Nov 30, 2023
1 parent e181c44 commit f0b3819
Showing 1 changed file with 30 additions and 12 deletions.
42 changes: 30 additions & 12 deletions src/iterative_ensemble_smoother/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
features of iterative_ensemble_smoother
"""
import numbers
from typing import List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt
import scipy as sp
import scipy as sp # type: ignore

from iterative_ensemble_smoother import ESMDA
from iterative_ensemble_smoother.esmda import BaseESMDA
Expand All @@ -16,7 +16,9 @@
)


def groupby_indices(X):
def groupby_indices(
X: npt.NDArray[Any],
) -> Generator[Dict[npt.NDArray[np.double], npt.NDArray[np.int_]], None, None]:
"""Yield pairs of (unique_row, indices_of_row).
Examples
Expand Down Expand Up @@ -74,7 +76,7 @@ def correlation_threshold(ensemble_size: int) -> float:
>>> AdaptiveESMDA.correlation_threshold(36)
0.5
"""
return min(1, max(0, 3 / np.sqrt(ensemble_size)))
return float(min(1, max(0, 3 / np.sqrt(ensemble_size))))

@staticmethod
def compute_cross_covariance_multiplier(
Expand Down Expand Up @@ -137,24 +139,40 @@ def compute_cross_covariance_multiplier(
# C_D is an array, so add it to the diagonal without forming diag(C_D)
np.fill_diagonal(C_DD, C_DD.diagonal() + alpha * C_D)

return sp.linalg.solve(C_DD, D - Y, **solver_kwargs)
return sp.linalg.solve(C_DD, D - Y, **solver_kwargs) # type: ignore

def _correlation_matrix(self, cov_XY, X, Y):
def _correlation_matrix(
self,
cov_XY: npt.NDArray[np.double],
X: npt.NDArray[np.double],
Y: npt.NDArray[np.double],
) -> npt.NDArray[np.double]:
"""Compute a correlation matrix given a covariance matrix."""
assert cov_XY.shape == (X.shape[0], Y.shape[0])

stds_Y = np.std(Y, axis=1, ddof=1)
stds_X = np.std(X, axis=1, ddof=1)

# Compute the correlation matrix from the covariance matrix
corr_XY = (cov_XY / stds_X[:, np.newaxis]) / stds_Y[np.newaxis, :]
corr_XY: npt.NDArray[np.double] = (cov_XY / stds_X[:, np.newaxis]) / stds_Y[
np.newaxis, :
]

# Perform checks
assert corr_XY.max() <= 1
assert corr_XY.min() >= -1
return corr_XY

def assimilate(self, X, Y, D, alpha, correlation_threshold=None, verbose=False):
def assimilate(
self,
*,
X: npt.NDArray[np.double],
Y: npt.NDArray[np.double],
D: npt.NDArray[np.double],
alpha: float,
correlation_threshold: Union[Callable[[int], float], float, None] = None,
verbose: bool = False,
) -> npt.NDArray[np.double]:
"""Assimilate data and return an updated ensemble X_posterior.
X_posterior = smoother.assimilate(X, Y, D, alpha)
Expand Down Expand Up @@ -213,9 +231,9 @@ def assimilate(self, X, Y, D, alpha, correlation_threshold=None, verbose=False):

# Create `correlation_threshold` if the argument is a float
if is_float:
corr_threshold = correlation_threshold
corr_threshold: float = correlation_threshold # type: ignore

def correlation_threshold(ensemble_size):
def correlation_threshold(ensemble_size: int) -> float:
return corr_threshold

# Default correlation threshold function
Expand All @@ -238,14 +256,14 @@ def correlation_threshold(ensemble_size):

# Determine which elements in the cross covariance matrix that will
# be set to zero
threshold = correlation_threshold(ensemble_size=X.shape[1])
threshold = correlation_threshold(X.shape[1])
significant_corr_XY = np.abs(corr_XY) > threshold

# Pre-compute the covariance cov(Y, Y) here, and index on it later
cov_YY = empirical_cross_covariance(Y, Y)

# TODO: memory could be saved by overwriting the input X
X_out = np.copy(X)
X_out: npt.NDArray[np.double] = np.copy(X)
for (unique_row, indices_of_row) in groupby_indices(significant_corr_XY):

if verbose:
Expand Down

0 comments on commit f0b3819

Please sign in to comment.