
Commit

Test that allowing overwrites in ESMDA saves memory (#163)
tommyod authored Oct 23, 2023
1 parent df271ef commit dab58e8
Showing 1 changed file with 45 additions and 32 deletions.
77 changes: 45 additions & 32 deletions tests/test_esmda.py
@@ -321,7 +321,7 @@ def setup(self):
         return X_prior, Y_prior, covariance, observations
 
     @pytest.mark.limit_memory("138 MB")
-    def test_ESMDA_memory_usage_subspace_inversion(self, setup):
+    def test_ESMDA_memory_usage_subspace_inversion_without_overwrite(self, setup):
         # TODO: Currently this is a regression test. Work to improve memory usage.
 
         X_prior, Y_prior, covariance, observations = setup
@@ -332,46 +332,59 @@ def test_ESMDA_memory_usage_subspace_inversion(self, setup):
         for _ in range(esmda.num_assimilations()):
             esmda.assimilate(X_prior, Y_prior)
 
-    @pytest.mark.parametrize("inversion", ["exact", "subspace"])
-    @pytest.mark.parametrize("dtype", [np.float32, np.float64])
-    @pytest.mark.parametrize("diagonal", [True, False])
-    def test_that_float_dtypes_are_preserved(self, inversion, dtype, diagonal):
-        """If every matrix passed is of a certain dtype, then the output
-        should also be of the same dtype. 'linalg' does not support float16
-        nor float128."""
+    @pytest.mark.limit_memory("129 MB")
+    def test_ESMDA_memory_usage_subspace_inversion_with_overwrite(self, setup):
+        # TODO: Currently this is a regression test. Work to improve memory usage.
 
-        rng = np.random.default_rng(42)
+        X_prior, Y_prior, covariance, observations = setup
 
-        num_outputs = 20
-        num_inputs = 10
-        num_ensemble = 25
+        # Create ESMDA instance from an integer `alpha` and run it
+        esmda = ESMDA(covariance, observations, alpha=1, seed=1, inversion="subspace")
 
-        # Prior is N(0, 1)
-        X_prior = rng.normal(size=(num_inputs, num_ensemble))
-        Y_prior = rng.normal(size=(num_outputs, num_ensemble))
+        for _ in range(esmda.num_assimilations()):
+            esmda.assimilate(X_prior, Y_prior, overwrite=True)
 
-        # Measurement errors
-        covariance = np.exp(rng.normal(size=num_outputs))
-        if not diagonal:
-            covariance = np.diag(covariance)
 
-        # Observations
-        observations = rng.normal(size=num_outputs, loc=1)
+@pytest.mark.parametrize("inversion", ["exact", "subspace"])
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("diagonal", [True, False])
+def test_that_float_dtypes_are_preserved(inversion, dtype, diagonal):
+    """If every matrix passed is of a certain dtype, then the output
+    should also be of the same dtype. 'linalg' does not support float16
+    nor float128."""
 
-        # Convert types
-        X_prior = X_prior.astype(dtype)
-        Y_prior = Y_prior.astype(dtype)
-        covariance = covariance.astype(dtype)
-        observations = observations.astype(dtype)
+    rng = np.random.default_rng(42)
 
-        # Create ESMDA instance from an integer `alpha` and run it
-        esmda = ESMDA(covariance, observations, alpha=1, seed=1, inversion=inversion)
+    num_outputs = 20
+    num_inputs = 10
+    num_ensemble = 25
 
-        for _ in range(esmda.num_assimilations()):
-            X_posterior = esmda.assimilate(X_prior, Y_prior)
+    # Prior is N(0, 1)
+    X_prior = rng.normal(size=(num_inputs, num_ensemble))
+    Y_prior = rng.normal(size=(num_outputs, num_ensemble))
+
+    # Measurement errors
+    covariance = np.exp(rng.normal(size=num_outputs))
+    if not diagonal:
+        covariance = np.diag(covariance)
+
+    # Observations
+    observations = rng.normal(size=num_outputs, loc=1)
+
+    # Convert types
+    X_prior = X_prior.astype(dtype)
+    Y_prior = Y_prior.astype(dtype)
+    covariance = covariance.astype(dtype)
+    observations = observations.astype(dtype)
+
+    # Create ESMDA instance from an integer `alpha` and run it
+    esmda = ESMDA(covariance, observations, alpha=1, seed=1, inversion=inversion)
+
+    for _ in range(esmda.num_assimilations()):
+        X_posterior = esmda.assimilate(X_prior, Y_prior)
 
-        # Check that dtype of returned array matches input dtype
-        assert X_posterior.dtype == dtype
+        # Check that dtype of returned array matches input dtype
+        assert X_posterior.dtype == dtype
 
 
 if __name__ == "__main__":
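As a usage reference, here is a minimal, self-contained sketch of the overwrite path the new test exercises. The import path iterative_ensemble_smoother and the toy problem sizes are assumptions for illustration; the ESMDA constructor, num_assimilations(), and assimilate(..., overwrite=True) calls mirror the diff above.

import numpy as np

from iterative_ensemble_smoother import ESMDA  # assumed import path

rng = np.random.default_rng(42)
num_inputs, num_outputs, num_ensemble = 10, 20, 25  # illustrative sizes

# Prior parameter ensemble and corresponding simulated responses
X_prior = rng.normal(size=(num_inputs, num_ensemble))
Y_prior = rng.normal(size=(num_outputs, num_ensemble))

# Diagonal of the measurement error covariance, and the observed values
covariance = np.exp(rng.normal(size=num_outputs))
observations = rng.normal(size=num_outputs, loc=1)

esmda = ESMDA(covariance, observations, alpha=1, seed=1, inversion="subspace")

for _ in range(esmda.num_assimilations()):
    # overwrite=True lets the update reuse the input ensembles instead of
    # keeping internal copies, which is what the lower memory limit in the
    # tests above (129 MB with overwrite vs 138 MB without) is meant to capture.
    esmda.assimilate(X_prior, Y_prior, overwrite=True)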
