diff --git a/src/iterative_ensemble_smoother/experimental.py b/src/iterative_ensemble_smoother/experimental.py index df8886c8..77b68595 100644 --- a/src/iterative_ensemble_smoother/experimental.py +++ b/src/iterative_ensemble_smoother/experimental.py @@ -11,11 +11,14 @@ class RowScaling: + def __init__(self, alpha=1.0): + """Alpha is the strength of the update.""" + assert 0 <= alpha <= 1.0 + self.alpha = alpha + def multiply(self, X, K): - """Takes a matrix X and a matrix K and performs X @ K.""" - # TODO: Not sure why a RowScaling class is needed if - # all that it's used for is matrix multiplication - return X @ K + """Takes a matrix X and a matrix K and performs alpha * X @ K.""" + return X @ (K * self.alpha) def ensemble_smoother_update_step_row_scaling( @@ -96,8 +99,11 @@ def ensemble_smoother_update_step_row_scaling( covariance = np.exp(rng.normal(size=num_observations)) observations = rng.normal(size=num_observations, loc=1) + # Split up X into groups of row_groups = [(0,), (1, 2), (4, 5, 6), tuple(range(7, 100))] - X_with_row_scaling = [(X[idx, :], RowScaling()) for idx in row_groups] + X_with_row_scaling = [ + (X[idx, :], RowScaling(alpha=1 / (i + 1))) for i, idx in enumerate(row_groups) + ] X_before = deepcopy(X_with_row_scaling) X_with_row_scaling_updated = ensemble_smoother_update_step_row_scaling(