Skip to content

Commit

Permalink
Added option to specify GEN_OBS nodes in the form nodename:index
Browse files Browse the repository at this point in the history
This makes it possible to specify individual observations from a GEN_OBS node
in localisation.
Add tests to verify active param and active obs in ministeps
Updated integration tests with consistency check for active obs and param
  • Loading branch information
oddvarlia committed Jan 12, 2022
1 parent b6ecec4 commit a6a1410
Show file tree
Hide file tree
Showing 427 changed files with 872 additions and 57 deletions.
17 changes: 8 additions & 9 deletions semeio/workflows/localisation/local_config_script.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
from ert_shared.libres_facade import LibresFacade
from ert_shared.plugins.plugin_manager import hook_implementation

import semeio.workflows.localisation.local_script_lib as local
from semeio.communication import SemeioScript
from semeio.workflows.localisation.localisation_config import LocalisationConfig
from semeio.workflows.localisation.localisation_config import (
LocalisationConfig,
get_max_gen_obs_size_for_expansion,
)


class LocalisationConfigJob(SemeioScript):
def run(self, *args):
ert = self.ert()
facade = LibresFacade(self.ert())

# Clear all correlations
local.clear_correlations(ert)

# Read yml file with specifications
config_dict = local.read_localisation_config(args)

# Get all observations from ert instance
obs_keys = [
facade.get_observation_key(nr)
for nr, _ in enumerate(facade.get_observations())
]
expand_gen_obs_max_size = get_max_gen_obs_size_for_expansion(config_dict)
obs_keys = local.get_obs_from_ert(ert, expand_gen_obs_max_size)

ert_parameters = local.get_param_from_ert(ert.ensembleConfig())

Expand All @@ -35,6 +34,7 @@ def run(self, *args):
ert_parameters.to_dict(),
ert.getLocalConfig(),
ert.ensembleConfig(),
ert.getObservations(),
ert.eclConfig().getGrid(),
)

Expand Down Expand Up @@ -363,7 +363,6 @@ def run(self, *args):
length and the first value in the **scalingfactors** list corresponds to
the first segment number in the **active_segments** list and so on.
"""


Expand Down
72 changes: 66 additions & 6 deletions semeio/workflows/localisation/local_script_lib.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# pylint: disable=W0201
# pylint: disable=C0302
import math
import yaml
import cwrap
Expand All @@ -15,13 +16,16 @@
from ecl.eclfile import Ecl3DKW
from ecl.ecl_type import EclDataType
from ecl.grid.ecl_grid import EclGrid

from res.enkf.enums.ert_impl_type_enum import ErtImplType
from res.enkf.enums.enkf_var_type_enum import EnkfVarType
from res.enkf import EnkfObservationImplementationType

from semeio.workflows.localisation.localisation_debug_settings import (
LogLevel,
debug_print,
)
from ert_shared.libres_facade import LibresFacade


@dataclass
Expand Down Expand Up @@ -524,11 +528,12 @@ def add_ministeps(
ert_param_dict,
ert_local_config,
ert_ensemble_config,
ert_obs,
grid_for_field,
):
# pylint: disable-msg=too-many-branches
# pylint: disable-msg=R0915

# pylint: disable-msg=R1702
debug_print("Add all ministeps:", LogLevel.LEVEL1, user_config.log_level)
ScalingValues.initialize()
# Read all region files used in correlation groups,
Expand All @@ -539,6 +544,7 @@ def add_ministeps(
)

for count, corr_spec in enumerate(user_config.correlations):

ministep_name = corr_spec.name
ministep = ert_local_config.createMinistep(ministep_name)
debug_print(
Expand All @@ -549,7 +555,7 @@ def add_ministeps(
obs_group_name = ministep_name + "_obs_group"
obs_group = ert_local_config.createObsdata(obs_group_name)

obs_list = corr_spec.obs_group.result_items
obs_dict = Parameters.from_list(corr_spec.obs_group.result_items).to_dict()
param_dict = Parameters.from_list(corr_spec.param_group.result_items).to_dict()

# Setup model parameter group
Expand Down Expand Up @@ -712,12 +718,38 @@ def add_ministeps(
user_config.log_level,
)

# Setup observation group
for obs_name in obs_list:
# Setup observation group. For the GEN_OBS type,
# the observation specification can be of the form obs_node_name:index
# if individual observations from a GEN_OBS node are chosen, or
# only obs_node_name if all observations in the GEN_OBS node are active.
obs_type = EnkfObservationImplementationType.GEN_OBS
key_list_gen_obs = ert_obs.getTypedKeylist(obs_type)
for obs_node_name, obs_index_list in obs_dict.items():
obs_group.addNode(obs_node_name)
debug_print(
f"Add obs node: {obs_name}", LogLevel.LEVEL2, user_config.log_level
f"Add obs node: {obs_node_name}", LogLevel.LEVEL2, user_config.log_level
)
obs_group.addNode(obs_name)
if obs_node_name in key_list_gen_obs:
# An observation node of type GEN_OBS
if len(obs_index_list) > 0:
active_obs_list = obs_group.getActiveList(obs_node_name)
if len(obs_index_list) > 50:
debug_print(
f"More than 50 active obs for {obs_node_name}",
LogLevel.LEVEL3,
user_config.log_level,
)

for string_index in obs_index_list:
index = int(string_index)
if len(obs_index_list) <= 50:
debug_print(
f"Active obs for {obs_node_name} index: {index}",
LogLevel.LEVEL3,
user_config.log_level,
)

active_obs_list.addActiveIndex(index)

# Setup ministep
debug_print(
Expand Down Expand Up @@ -853,3 +885,31 @@ def write_qc_parameter(
grid.write_grdecl(scaling_kw, file)
# Increase parameter number to define unique parameter name
cls.scaling_param_number = cls.scaling_param_number + 1


def get_obs_from_ert(ert, expand_gen_obs_max_size):
    """Return the list of observation keys defined in the ERT instance.

    If ``expand_gen_obs_max_size`` is 0, the keys are returned exactly as
    reported by the facade. Otherwise every observation whose implementation
    type name is ``"GEN_OBS"`` and whose node size is at most
    ``expand_gen_obs_max_size`` is replaced by one ``key:index`` entry per
    observation in the node (index running from 0 to size - 1). Larger
    GEN_OBS nodes and all other observation types keep their plain key.
    """
    libres = LibresFacade(ert)
    observations = libres.get_observations()

    # No expansion requested: plain keys only.
    if expand_gen_obs_max_size == 0:
        return [libres.get_observation_key(pos) for pos, _ in enumerate(observations)]

    result = []
    for pos, _ in enumerate(observations):
        name = libres.get_observation_key(pos)
        if libres.get_impl_type_name_for_obs_key(name) != "GEN_OBS":
            result.append(name)
            continue

        # Size of the GEN_OBS node at the vector's active step.
        vector = observations[name]
        node = vector.getNode(vector.activeStep())
        size = node.getSize()

        if size > expand_gen_obs_max_size:
            # Too large to expand; refer to the node as a whole.
            result.append(name)
        else:
            result.extend(f"{name}:{idx}" for idx in range(size))
    return result
38 changes: 38 additions & 0 deletions semeio/workflows/localisation/localisation_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,28 @@ def validate_surface_scale(cls, value):
)


class MaxGenObsSize(PydanticBaseModel):
    """
    max_gen_obs_size: Integer >= 0. Default: 0
        Threshold controlling expansion of GEN_OBS observation nodes.

        If the value is > 0, GEN_OBS observations are expanded into the
        form nodename:index. The user must then specify GEN_OBS type
        observations in the form nodename:index, or nodename:* if all
        observations of a GEN_OBS node are used.

        If a GEN_OBS node has more observations than max_gen_obs_size,
        that node is not expanded, and the user must specify it by its
        nodename only, not in expanded form.

        Typical use is to let nodes containing a moderate number of
        observations be expanded, while nodes having a large number of
        observations are not.
    """

    max_gen_obs_size: Optional[conint(ge=0)] = 0


class LocalisationConfig(BaseModel):
"""
observations: A list of observations from ERT in format nodename
Expand All @@ -309,13 +331,23 @@ class LocalisationConfig(BaseModel):
log_level: Integer defining how much log output to write to screen
write_scaling_factors: Turn on writing calculated scaling parameters to file.
Possible values: True/False. Default: False
max_gen_obs_size: Integer defining max size for a GEN_OBS node to
be expanded in the form nodename:index.
If the observation node of type GEN_OBS has more
observations than this number, it can only be specified with
node name which then represents the whole set of
observations for the node.
Possible values: Integers >= 0
Default: 0 which means that GEN_OBS nodes are specified
with node name only.
"""

observations: List[str]
parameters: List[str]
correlations: List[CorrelationConfig]
log_level: Optional[conint(ge=0, le=5)] = 1
write_scaling_factors: Optional[bool] = False
max_gen_obs_size: Optional[conint(ge=0)] = 0

@validator("log_level")
def validate_log_level(cls, level):
Expand Down Expand Up @@ -348,3 +380,9 @@ def _check_specification(items_to_add, items_to_remove, valid_items):
added_items = added_items.difference(removed_items)
added_items = list(added_items)
return sorted(added_items)


def get_max_gen_obs_size_for_expansion(config_dict):
    """Return the ``max_gen_obs_size`` setting taken from ``config_dict``.

    The dictionary is passed as keyword arguments to ``MaxGenObsSize``; the
    value of the resulting model's ``max_gen_obs_size`` field is returned.
    """
    return MaxGenObsSize(**config_dict).max_gen_obs_size
15 changes: 15 additions & 0 deletions tests/jobs/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,18 @@ def setup_poly_ert(tmpdir, test_data_root):

yield
os.chdir(cwd)


@pytest.fixture()
def setup_poly_gen_param_ert(tmpdir, test_data_root):
    """Yield a ``ResConfig`` loaded from a copy of the ``poly_gen_param`` data.

    The ``poly_gen_param`` directory under ``test_data_root`` is copied into
    ``tmpdir`` as ``test_data``. The working directory is moved into that copy
    while the fixture is active, and ``poly.ert`` is loaded from there.

    The original working directory is always restored, including when setup
    fails before the fixture has yielded.
    """
    cwd = os.getcwd()
    try:
        tmpdir.chdir()
        test_data_dir = os.path.join(test_data_root, "poly_gen_param")
        shutil.copytree(test_data_dir, "test_data")
        os.chdir("test_data")

        res_config = ResConfig("poly.ert")

        yield res_config
    finally:
        # pytest skips post-yield teardown if the fixture raises before
        # yielding, so restore the cwd here to avoid leaking the tmpdir
        # working directory into later tests.
        os.chdir(cwd)
Loading

0 comments on commit a6a1410

Please sign in to comment.