From 3efe67c78a992ab88e6f1bd49fb4f0948e645072 Mon Sep 17 00:00:00 2001 From: Tommy Odland Date: Fri, 27 Oct 2023 07:24:17 +0200 Subject: [PATCH] update example to use alpha as update strength --- src/iterative_ensemble_smoother/experimental.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/iterative_ensemble_smoother/experimental.py b/src/iterative_ensemble_smoother/experimental.py index df8886c8..77b68595 100644 --- a/src/iterative_ensemble_smoother/experimental.py +++ b/src/iterative_ensemble_smoother/experimental.py @@ -11,11 +11,14 @@ class RowScaling: + def __init__(self, alpha=1.0): + """Alpha is the strength of the update.""" + assert 0 <= alpha <= 1.0 + self.alpha = alpha + def multiply(self, X, K): - """Takes a matrix X and a matrix K and performs X @ K.""" - # TODO: Not sure why a RowScaling class is needed if - # all that it's used for is matrix multiplication - return X @ K + """Takes a matrix X and a matrix K and performs alpha * X @ K.""" + return X @ (K * self.alpha) def ensemble_smoother_update_step_row_scaling( @@ -96,8 +99,11 @@ def ensemble_smoother_update_step_row_scaling( covariance = np.exp(rng.normal(size=num_observations)) observations = rng.normal(size=num_observations, loc=1) + # Split up X into groups of row_groups = [(0,), (1, 2), (4, 5, 6), tuple(range(7, 100))] - X_with_row_scaling = [(X[idx, :], RowScaling()) for idx in row_groups] + X_with_row_scaling = [ + (X[idx, :], RowScaling(alpha=1 / (i + 1))) for i, idx in enumerate(row_groups) + ] X_before = deepcopy(X_with_row_scaling) X_with_row_scaling_updated = ensemble_smoother_update_step_row_scaling(