Skip to content

Commit

Permalink
propagate truncation parameters down
Browse files Browse the repository at this point in the history
  • Loading branch information
Tommy Odland committed Oct 27, 2023
1 parent 8d6a633 commit 0c8547e
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/iterative_ensemble_smoother/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
class RowScaling:
def multiply(self, X, K):
"""Takes a matrix X and a matrix K and performs X @ K."""
# TODO: Not sure why a RowScaling class is needed if
# all that it's used for is matrix multiplication
return X @ K


Expand Down Expand Up @@ -68,7 +70,9 @@ def ensemble_smoother_update_step_row_scaling(
)

# Create transition matrix - common to all parameters in X
transition_matrix = smoother.compute_transition_matrix(Y=Y, alpha=1, truncation=1.0)
transition_matrix = smoother.compute_transition_matrix(
Y=Y, alpha=1, truncation=truncation
)

# Loop over groups of rows (parameters)
for X, row_scale in X_with_row_scaling:
Expand All @@ -78,6 +82,7 @@ def ensemble_smoother_update_step_row_scaling(


if __name__ == "__main__":
from copy import deepcopy

# Example showing how to use row scaling
num_parameters = 100
Expand All @@ -93,7 +98,7 @@ def ensemble_smoother_update_step_row_scaling(

row_groups = [(0,), (1, 2), (4, 5, 6), tuple(range(7, 100))]
X_with_row_scaling = [(X[idx, :], RowScaling()) for idx in row_groups]
X_with_row_scaling = X_with_row_scaling.copy()
X_before = deepcopy(X_with_row_scaling)

X_with_row_scaling_updated = ensemble_smoother_update_step_row_scaling(
covariance=covariance,
Expand All @@ -102,3 +107,6 @@ def ensemble_smoother_update_step_row_scaling(
Y=Y,
seed=rng,
)

# Check that an update happened
assert not np.allclose(X_before[-1][0], X_with_row_scaling_updated[-1][0])

0 comments on commit 0c8547e

Please sign in to comment.