Skip to content

Commit

Permalink
more descriptive method names
Browse files Browse the repository at this point in the history
  • Loading branch information
tommyod committed Oct 19, 2023
1 parent 5a77c23 commit a1c26fc
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
13 changes: 9 additions & 4 deletions src/iterative_ensemble_smoother/esmda.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def assimilate(
# a zero cented normal means that y := L @ z, where z ~ norm(0, 1)
# Therefore, scaling C_D by alpha is equivalent to scaling L with sqrt(alpha)
size = (num_outputs, num_ensemble)
D = self.get_D(size=size, alpha=self.alpha[self.iteration])
D = self.perturb_observations(size=size, alpha=self.alpha[self.iteration])
assert D.shape == (num_outputs, num_ensemble)

# Line 2 (c) in the description of ES-MDA in the 2013 Emerick paper
Expand All @@ -245,7 +245,7 @@ def assimilate(
self.iteration += 1
return X

def get_K(
def compute_transition_matrix(
self,
Y: npt.NDArray[np.double],
*,
Expand Down Expand Up @@ -279,7 +279,10 @@ def get_K(
It has shape (num_ensemble_members, num_ensemble_members).
"""

D = self.get_D(size=Y.shape, alpha=alpha)
# Recall the update equation
# X += C_MD @ (C_DD + alpha * C_D)^(-1) @ (D - Y)

D = self.perturb_observations(size=Y.shape, alpha=alpha)
inversion_func = self._inversion_methods[self.inversion]
return inversion_func(
alpha=alpha,
Expand All @@ -291,7 +294,9 @@ def get_K(
return_K=True, # Ensures that we don't need X
)

def get_D(self, *, size: Tuple[int, int], alpha: float) -> npt.NDArray[np.double]:
def perturb_observations(
self, *, size: Tuple[int, int], alpha: float
) -> npt.NDArray[np.double]:
"""Create a matrix D with perturbed observations.
In the Emerick (2013) paper, the matrix D is defined in section 6.
Expand Down
3 changes: 0 additions & 3 deletions src/iterative_ensemble_smoother/esmda_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,9 +435,6 @@ def inversion_subspace(
[0., 0., 0.]])
"""
print(
{"alpha": alpha, "C_D": C_D, "D": D, "Y": Y, "X": X, "truncation": truncation}
)

# N_n is the number of observations
# N_e is the number of members in the ensemble
Expand Down
2 changes: 1 addition & 1 deletion tests/test_esmda.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def g(X):
)
X = np.copy(X_prior)
for alpha_i in smoother.alpha:
K = smoother.get_K(Y=g(X), alpha=alpha_i)
K = smoother.compute_transition_matrix(Y=g(X), alpha=alpha_i)

# TODO: Why is this equivalent? ...
X_centered = X - np.mean(X, axis=1, keepdims=True)
Expand Down

0 comments on commit a1c26fc

Please sign in to comment.