Skip to content

Commit

Permalink
Add progress callback to adloc assimilate (#212)
Browse files Browse the repository at this point in the history
* Add progress callback to adloc assimilate

* Remove unused verbose argument
  • Loading branch information
dafeda authored Feb 6, 2024
1 parent 92a0758 commit a48841b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
4 changes: 1 addition & 3 deletions docs/source/Adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,7 @@ def g(X):
)

# Assimilate data
X_i = adaptive_smoother.assimilate(
X=X_i, Y=Y_i, D=D_i, alpha=alpha_i, verbose=False
)
X_i = adaptive_smoother.assimilate(X=X_i, Y=Y_i, D=D_i, alpha=alpha_i)


X_adaptive_posterior = np.copy(X_i)
Expand Down
24 changes: 18 additions & 6 deletions src/iterative_ensemble_smoother/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
import numbers
import warnings
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Sequence, Tuple, TypeVar, Union

import numpy as np
import numpy.typing as npt
Expand All @@ -16,6 +16,8 @@
empirical_cross_covariance,
)

T = TypeVar("T")


class AdaptiveESMDA(BaseESMDA):
@staticmethod
Expand Down Expand Up @@ -145,7 +147,7 @@ def assimilate(
alpha: float,
correlation_threshold: Union[Callable[[int], float], float, None] = None,
cov_YY: Optional[npt.NDArray[np.double]] = None,
verbose: bool = False,
progress_callback: Optional[Callable[[Sequence[T]], Sequence[T]]] = None,
) -> npt.NDArray[np.double]:
"""Assimilate data and return an updated ensemble X_posterior.
Expand Down Expand Up @@ -188,9 +190,14 @@ def assimilate(
A 2D array of shape (num_observations, num_observations) with the
empirical covariance of Y. If passed, this is not computed in the
method call, potentially saving time and computation.
verbose : bool
Whether to print information.
progress_callback : Callable[[Sequence[T]], Sequence[T]] or None
A callback function that can be used to wrap the iteration over
parameters for progress reporting.
It should accept an iterable as input and return an iterable.
This allows for integration with progress reporting tools like tqdm,
which can provide visual feedback on the progress of the
assimilation process.
If None, no progress reporting is performed.
Returns
-------
X_posterior : np.ndarray
Expand Down Expand Up @@ -229,6 +236,11 @@ def correlation_threshold(ensemble_size: int) -> float:
correlation_threshold
), "`correlation_threshold` should be callable"

if progress_callback is None:

def progress_callback(x):
return x # A simple pass-through function

# Step 1: # Compute cross-correlation between parameters X and responses Y
# Note: let the number of parameters be n and the number of responses be m.
# This step requires both O(mn) computation and O(mn) storage, which is
Expand Down Expand Up @@ -261,7 +273,7 @@ def correlation_threshold(ensemble_size: int) -> float:
significant_rows = np.any(significant_corr_XY, axis=1)

# Loop only over rows with significant correlations
for param_num in np.where(significant_rows)[0]:
for param_num in progress_callback(np.where(significant_rows)[0]):
correlated_responses = significant_corr_XY[param_num]

Y_subset = Y[correlated_responses, :]
Expand Down
1 change: 0 additions & 1 deletion tests/test_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,6 @@ def zero_correlation_threshold(ensemble_size):
# Pass a function that always returns zero,
# no matter what the ensemble size is
correlation_threshold=zero_correlation_threshold,
verbose=True,
)

print()
Expand Down

0 comments on commit a48841b

Please sign in to comment.