Skip to content

Commit

Permalink
Implement adaptive localization
Browse files Browse the repository at this point in the history
Add option of running adaptive localization that can simply
be turned on and does not need any user input.
Only parameters that are significantly correlated to responses
will be updated.
Default value of what constitutes significant correlation is calculated
based on theory, but can be set by the user.
  • Loading branch information
dafeda committed Oct 17, 2023
1 parent 4180da7 commit d80f169
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 43 deletions.
128 changes: 110 additions & 18 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from iterative_ensemble_smoother.experimental import (
ensemble_smoother_update_step_row_scaling,
)
from tqdm import tqdm

from ert.config import Field, GenKwConfig, SurfaceConfig
from ert.realization_state import RealizationState
Expand Down Expand Up @@ -370,6 +371,12 @@ def _load_observations_and_responses(
)


def _split_by_batchsize(
arr: npt.NDArray[np.int_], batch_size: int
) -> List[npt.NDArray[np.int_]]:
return np.array_split(arr, int((arr.shape[0] / batch_size)) + 1)


def analysis_ES(
updatestep: UpdateConfiguration,
obs: EnkfObs,
Expand Down Expand Up @@ -417,21 +424,17 @@ def analysis_ES(

# pylint: disable=unsupported-assignment-operation
smoother_snapshot.update_step_snapshots[update_step.name] = update_snapshot
if len(observation_values) == 0:

num_obs = len(observation_values)
if num_obs == 0:
raise ErtAnalysisError(
f"No active observations for update step: {update_step.name}."
)
noise = rng.standard_normal(size=(len(observation_values), S.shape[1]))

smoother = ies.ES()
smoother.fit(
S,
observation_errors,
observation_values,
noise=noise,
truncation=module.get_truncation(),
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
)
truncation = module.get_truncation()
noise = rng.standard_normal(size=(num_obs, ensemble_size))

for param_group in update_step.parameters:
source: Union[EnsembleReader, EnsembleAccessor]
if target_fs.has_parameter_group(param_group.name):
Expand All @@ -441,15 +444,92 @@ def analysis_ES(
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)
progress_callback(Progress(Task("Updating data", 2, 3), None))
if active_indices := param_group.index_list:
temp_storage[param_group.name][active_indices, :] = smoother.update(
temp_storage[param_group.name][active_indices, :]

if module.localization():
Y_prime = S - S.mean(axis=1, keepdims=True)
C_YY = Y_prime @ Y_prime.T / (ensemble_size - 1)
Sigma_Y = np.diag(np.sqrt(np.diag(C_YY)))
batch_size: int = 1000
correlation_threshold = module.localization_correlation_threshold(
ensemble_size
)
# for parameter in update_step.parameters:
num_params = temp_storage[param_group.name].shape[0]

print(
(
f"Running localization on {num_params} parameters,",
f"{num_obs} responses and {ensemble_size} realizations...",
)
)
batches = _split_by_batchsize(np.arange(0, num_params), batch_size)
for param_batch_idx in tqdm(batches):
X_local = temp_storage[param_group.name][param_batch_idx, :]
A = X_local - X_local.mean(axis=1, keepdims=True)
C_AA = A @ A.T / (ensemble_size - 1)

# State-measurement covariance matrix
C_AY = A @ Y_prime.T / (ensemble_size - 1)
Sigma_A = np.diag(np.sqrt(np.diag(C_AA)))

# State-measurement correlation matrix
c_AY = np.abs(
np.linalg.inv(Sigma_A) @ C_AY @ np.linalg.inv(Sigma_Y)
)
c_bool = c_AY > correlation_threshold
# Some parameters might be significantly correlated
# to the exact same responses,
# making up what we call a `parameter group``.
# We want to call the update only once per such parameter group
# to speed up computation.
param_groups = np.unique(c_bool, axis=0)

# Drop the parameter group that does not correlate to any responses.
row_with_all_false = np.all(~param_groups, axis=1)

Check failure on line 488 in src/ert/analysis/_es_update.py

View workflow job for this annotation

GitHub Actions / type-checking (3.11)

Unsupported operand type for ~ ("list[str]")

Check failure on line 488 in src/ert/analysis/_es_update.py

View workflow job for this annotation

GitHub Actions / type-checking (3.11)

Unsupported operand type for ~ ("list[str]")
param_groups = param_groups[~row_with_all_false]

for grp in param_groups:
# Find the rows matching the parameter group
matching_rows = np.all(c_bool == grp, axis=1)
# Get the indices of the matching rows
row_indices = np.where(matching_rows)[0]
X_chunk = temp_storage[param_group.name][param_batch_idx, :][
row_indices, :
]
S_chunk = S[grp, :]
observation_errors_loc = observation_errors[grp]
observation_values_loc = observation_values[grp]
smoother.fit(
S_chunk,
observation_errors_loc,
observation_values_loc,
noise=noise[grp],

Check failure on line 506 in src/ert/analysis/_es_update.py

View workflow job for this annotation

GitHub Actions / type-checking (3.11)

No overload variant of "__getitem__" of "ndarray" matches argument type "str"

Check failure on line 506 in src/ert/analysis/_es_update.py

View workflow job for this annotation

GitHub Actions / type-checking (3.11)

No overload variant of "__getitem__" of "ndarray" matches argument type "str"
truncation=truncation,
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
)
temp_storage[param_group.name][
param_batch_idx[row_indices], :
] = smoother.update(X_chunk)
else:
temp_storage[param_group.name] = smoother.update(
temp_storage[param_group.name]
smoother.fit(
S,
observation_errors,
observation_values,
noise=noise,
truncation=truncation,
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
)
if active_indices := param_group.index_list:
temp_storage[param_group.name][active_indices, :] = smoother.update(
temp_storage[param_group.name][active_indices, :]
)
else:
temp_storage[param_group.name] = smoother.update(
temp_storage[param_group.name]
)

if params_with_row_scaling := _get_params_with_row_scaling(
temp_storage, update_step.row_scaling_parameters
):
Expand All @@ -465,7 +545,19 @@ def analysis_ES(
for row_scaling_parameter, (A, _) in zip(
update_step.row_scaling_parameters, params_with_row_scaling
):
_save_to_temp_storage(temp_storage, [row_scaling_parameter], A)
params_with_row_scaling = ensemble_smoother_update_step_row_scaling(
S,
params_with_row_scaling,
observation_errors,
observation_values,
noise,
module.get_truncation(),
ies.InversionType(module.inversion),
)
for row_scaling_parameter, (A, _) in zip(
update_step.row_scaling_parameters, params_with_row_scaling
):
_save_to_temp_storage(temp_storage, [row_scaling_parameter], A)

progress_callback(Progress(Task("Storing data", 3, 3), None))
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)
Expand Down
124 changes: 100 additions & 24 deletions src/ert/config/analysis_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import math
import sys
from typing import TYPE_CHECKING, Dict, List, Type, TypedDict, Union

Expand Down Expand Up @@ -33,6 +34,23 @@ class VariableInfo(TypedDict):
DEFAULT_IES_DEC_STEPLENGTH = 2.50
DEFAULT_ENKF_TRUNCATION = 0.98
DEFAULT_IES_INVERSION = 0
DEFAULT_LOCALIZATION = False
# Default threshold is a function of ensemble size which is not available here.
DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD = -1


def correlation_threshold(ensemble_size: int, user_defined_threshold: float) -> float:
"""Decides whether or not to use user-defined or default threshold.
Default threshold taken from luo2022,
Continuous Hyper-parameter OPtimization (CHOP) in an ensemble Kalman filter
Section 2.3 - Localization in the CHOP problem
"""
default_threshold = 3 / math.sqrt(ensemble_size)
if user_defined_threshold == -1:
return default_threshold

return user_defined_threshold


class AnalysisMode(StrEnum):
Expand All @@ -58,6 +76,22 @@ def get_mode_variables(mode: AnalysisMode) -> Dict[str, "VariableInfo"]:
"step": 0.01,
"labelname": "Singular value truncation",
},
"LOCALIZATION": {
"type": bool,
"min": 0.0,
"value": DEFAULT_LOCALIZATION,
"max": 1.0,
"step": 1.0,
"labelname": "Adaptive localization",
},
"LOCALIZATION_CORRELATION_THRESHOLD": {
"type": float,
"min": 0.0,
"value": DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD,
"max": 1.0,
"step": 0.1,
"labelname": "Adaptive localization correlation threshold",
},
}
ies_variables: Dict[str, "VariableInfo"] = {
"IES_MAX_STEPLENGTH": {
Expand Down Expand Up @@ -169,31 +203,47 @@ def set_var(self, var_name: str, value: Union[float, int, bool, str]) -> None:
self.handle_special_key_set(var_name, value)
elif var_name in self._variables:
var = self._variables[var_name]
try:
new_value = var["type"](value)
if new_value > var["max"]:
var["value"] = var["max"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using max value {var['max']}"
)
elif new_value < var["min"]:
var["value"] = var["min"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using min value {var['min']}"

if var["type"] is not bool:
try:
new_value = var["type"](value)
if new_value > var["max"]:
var["value"] = var["max"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using max value {var['max']}"
)
elif new_value < var["min"]:
var["value"] = var["min"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using min value {var['min']}"
)
else:
var["value"] = new_value

except ValueError as e:
raise ConfigValidationError(
f"Variable {var_name!r} with value {value!r} has "
f"incorrect type."
f" Expected type {var['type'].__name__!r} but received"
f" value {value!r} of type {type(value).__name__!r}"
) from e
else:
if not isinstance(var["value"], bool):
raise ValueError(
f"Variable {var_name} expected type {var['type']}"
f" received value `{value}` of type `{type(value)}`"
)
else:
var["value"] = new_value

except ValueError as e:
raise ConfigValidationError(
f"Variable {var_name!r} with value {value!r} has incorrect type."
f" Expected type {var['type'].__name__!r} but received"
f" value {value!r} of type {type(value).__name__!r}"
) from e
# When config is first read, `value` is a string
# that's either "False" or "True",
# but since bool("False") is True we need to convert it to bool.
if not isinstance(value, bool):
value = str(value).lower() != "false"

var["value"] = var["type"](value)
else:
raise ConfigValidationError(
f"Variable {var_name!r} not found in {self.name!r} analysis module"
Expand All @@ -210,6 +260,32 @@ def inversion(self, value: int) -> None:
def get_truncation(self) -> float:
return self.get_variable_value("ENKF_TRUNCATION")

def localization(self) -> bool:
return bool(self.get_variable_value("LOCALIZATION"))

def localization_correlation_threshold(self, ensemble_size: int) -> float:
return correlation_threshold(
ensemble_size, self.get_variable_value("LOCALIZATION_CORRELATION_THRESHOLD")
)

def get_steplength(self, iteration_nr: int) -> float:
"""
This is an implementation of Eq. (49), which calculates a suitable
step length for the update step, from the book:
Geir Evensen, Formulating the history matching problem with
consistent error statistics, Computational Geosciences (2021) 25:945 –970
Function not really used moved from C to keep the class interface consistent
should be investigated for possible removal.
"""
min_step_length = self.get_variable_value("IES_MIN_STEPLENGTH")
max_step_length = self.get_variable_value("IES_MAX_STEPLENGTH")
dec_step_length = self.get_variable_value("IES_DEC_STEPLENGTH")
step_length = min_step_length + (max_step_length - min_step_length) * pow(
2, -(iteration_nr - 1) / (dec_step_length - 1)
)
return step_length

def __repr__(self) -> str:
return f"AnalysisModule(name = {self.name})"

Expand Down
19 changes: 19 additions & 0 deletions src/ert/gui/ertwidgets/analysismodulevariablespanel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
QWidget,
)

from ert.config.analysis_module import correlation_threshold
from ert.gui.ertwidgets.models.analysismodulevariablesmodel import (
AnalysisModuleVariablesModel,
)
Expand Down Expand Up @@ -41,10 +42,16 @@ def __init__(self, analysis_module_name: str, facade: LibresFacade):
variable_type = analysis_module_variables_model.getVariableType(
variable_name
)

variable_value = analysis_module_variables_model.getVariableValue(
self.facade, self._analysis_module_name, variable_name
)

if variable_name == "LOCALIZATION_CORRELATION_THRESHOLD":
variable_value = correlation_threshold(
self.facade.get_ensemble_size(), variable_value
)

label_name = analysis_module_variables_model.getVariableLabelName(
variable_name
)
Expand Down Expand Up @@ -123,6 +130,17 @@ def __init__(self, analysis_module_name: str, facade: LibresFacade):
lambda value: self.update_truncation_spinners(value, truncation_spinner)
)

localization_checkbox = self.widget_from_layout(layout, "LOCALIZATION")
localization_correlation_spinner = self.widget_from_layout(
layout, "LOCALIZATION_CORRELATION_THRESHOLD"
)
localization_correlation_spinner.setEnabled(localization_checkbox.isChecked())
localization_checkbox.stateChanged.connect(
lambda localization_is_on: localization_correlation_spinner.setEnabled(True)
if localization_is_on
else localization_correlation_spinner.setEnabled(False)
)

self.setLayout(layout)
self.blockSignals(False)

Expand Down Expand Up @@ -172,6 +190,7 @@ def createSpinBox(
def createCheckBox(self, variable_name, variable_value, variable_type):
spinner = QCheckBox()
spinner.setChecked(variable_value)
spinner.setObjectName(variable_name)
spinner.clicked.connect(
partial(self.valueChanged, variable_name, variable_type, spinner)
)
Expand Down
Loading

0 comments on commit d80f169

Please sign in to comment.