Skip to content

Commit

Permalink
Add progress callback to adloc assimilate
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Feb 6, 2024
1 parent 92a0758 commit b11b8d3
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions src/iterative_ensemble_smoother/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
import numbers
import warnings
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Iterable, List, Optional, Tuple, TypeVar, Union

import numpy as np
import numpy.typing as npt
Expand All @@ -16,6 +16,8 @@
empirical_cross_covariance,
)

T = TypeVar("T")


class AdaptiveESMDA(BaseESMDA):
@staticmethod
Expand Down Expand Up @@ -146,6 +148,7 @@ def assimilate(
correlation_threshold: Union[Callable[[int], float], float, None] = None,
cov_YY: Optional[npt.NDArray[np.double]] = None,
verbose: bool = False,
progress_callback: Optional[Callable[[Iterable[T]], Iterable[T]]] = None,
) -> npt.NDArray[np.double]:
"""Assimilate data and return an updated ensemble X_posterior.
Expand Down Expand Up @@ -190,7 +193,14 @@ def assimilate(
method call, potentially saving time and computation.
verbose : bool
Whether to print information.
progress_callback : Callable[[Iterable[T]], Iterable[T]] or None
A callback function that can be used to wrap the iteration over
parameters for progress reporting.
It should accept an iterable as input and return an iterable.
This allows for integration with progress reporting tools like tqdm,
which can provide visual feedback on the progress of the
assimilation process.
If None, no progress reporting is performed.
Returns
-------
X_posterior : np.ndarray
Expand Down Expand Up @@ -229,6 +239,11 @@ def correlation_threshold(ensemble_size: int) -> float:
correlation_threshold
), "`correlation_threshold` should be callable"

if progress_callback is None:

def progress_callback(x):
return x # A simple pass-through function

# Step 1: # Compute cross-correlation between parameters X and responses Y
# Note: let the number of parameters be n and the number of responses be m.
# This step requires both O(mn) computation and O(mn) storage, which is
Expand Down Expand Up @@ -261,7 +276,7 @@ def correlation_threshold(ensemble_size: int) -> float:
significant_rows = np.any(significant_corr_XY, axis=1)

# Loop only over rows with significant correlations
for param_num in np.where(significant_rows)[0]:
for param_num in progress_callback(np.where(significant_rows)[0]):
correlated_responses = significant_corr_XY[param_num]

Y_subset = Y[correlated_responses, :]
Expand Down

0 comments on commit b11b8d3

Please sign in to comment.