Skip to content

Commit

Permalink
Added option to specify GEN_OBS nodes in the form nodename:index
Browse files Browse the repository at this point in the history
  • Loading branch information
oddvarlia committed Oct 28, 2021
1 parent c5e9f18 commit f704553
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 13 deletions.
22 changes: 15 additions & 7 deletions semeio/workflows/localisation/local_config_script.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,33 @@
from ert_shared.libres_facade import LibresFacade
# from ert_shared.libres_facade import LibresFacade
from ert_shared.plugins.plugin_manager import hook_implementation

import semeio.workflows.localisation.local_script_lib as local
from semeio.communication import SemeioScript
from semeio.workflows.localisation.localisation_config import LocalisationConfig
from semeio.workflows.localisation.localisation_config import (
LocalisationConfig,
get_max_gen_obs_size_for_expansion,
)


class LocalisationConfigJob(SemeioScript):
def run(self, *args):
ert = self.ert()
facade = LibresFacade(self.ert())
# facade = LibresFacade(self.ert())
# Clear all correlations
local.clear_correlations(ert)

# Read yml file with specifications
config_dict = local.read_localisation_config(args)

# print(f"config_dict:\n {config_dict}")

# Get all observations from ert instance
obs_keys = [
facade.get_observation_key(nr)
for nr, _ in enumerate(facade.get_observations())
]
# obs_keys = [
# facade.get_observation_key(nr)
# for nr, _ in enumerate(facade.get_observations())
# ]
expand_gen_obs_max_size = get_max_gen_obs_size_for_expansion(config_dict)
obs_keys = local.get_obs_from_ert(ert, expand_gen_obs_max_size)

ert_parameters = local.get_param_from_ert(ert.ensembleConfig())

Expand All @@ -35,6 +42,7 @@ def run(self, *args):
ert_parameters.to_dict(),
ert.getLocalConfig(),
ert.ensembleConfig(),
ert.getObservations(),
ert.eclConfig().getGrid(),
)

Expand Down
82 changes: 76 additions & 6 deletions semeio/workflows/localisation/local_script_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@
from ecl.eclfile import Ecl3DKW
from ecl.ecl_type import EclDataType
from ecl.grid.ecl_grid import EclGrid

from res.enkf.enums.ert_impl_type_enum import ErtImplType
from res.enkf.enums.enkf_var_type_enum import EnkfVarType
from res.enkf import EnkfObservationImplementationType

# from res.enkf.enums.active_mode_enum import ActiveMode

from semeio.workflows.localisation.localisation_debug_settings import (
LogLevel,
debug_print,
)
from ert_shared.libres_facade import LibresFacade


@dataclass
Expand Down Expand Up @@ -524,11 +529,12 @@ def add_ministeps(
ert_param_dict,
ert_local_config,
ert_ensemble_config,
ert_obs,
grid_for_field,
):
# pylint: disable-msg=too-many-branches
# pylint: disable-msg=R0915

# pylint: disable-msg=R1702
debug_print("Add all ministeps:", LogLevel.LEVEL1, user_config.log_level)
ScalingValues.initialize()
# Read all region files used in correlation groups,
Expand All @@ -539,6 +545,7 @@ def add_ministeps(
)

for count, corr_spec in enumerate(user_config.correlations):

ministep_name = corr_spec.name
ministep = ert_local_config.createMinistep(ministep_name)
debug_print(
Expand All @@ -551,7 +558,8 @@ def add_ministeps(
obs_group_name = ministep_name + "_obs_group"
obs_group = ert_local_config.createObsdata(obs_group_name)

obs_list = corr_spec.obs_group.result_items
# obs_list = corr_spec.obs_group.result_items
obs_dict = Parameters.from_list(corr_spec.obs_group.result_items).to_dict()
param_dict = Parameters.from_list(corr_spec.param_group.result_items).to_dict()

# Setup model parameter group
Expand Down Expand Up @@ -714,12 +722,40 @@ def add_ministeps(
user_config.log_level,
)

# Setup observation group
for obs_name in obs_list:
# Setup observation group. For GEN_OBS type
# the observation specification can be of the form obs_node_name:index
# if individual observations from a GEN_OBS node is chosen or
# only obs_node_name if all observations in GEN_OBS is active.
obs_type = EnkfObservationImplementationType.GEN_OBS
key_list_gen_obs = ert_obs.getTypedKeylist(obs_type)
# print(f"key_list_gen_obs: {key_list_gen_obs}")
# print(f"obs_dict: {obs_dict}")
for obs_node_name, obs_index_list in obs_dict.items():
obs_group.addNode(obs_node_name)
debug_print(
f"Add obs node: {obs_name}", LogLevel.LEVEL2, user_config.log_level
f"Add obs node: {obs_node_name}", LogLevel.LEVEL2, user_config.log_level
)
obs_group.addNode(obs_name)
if obs_node_name in key_list_gen_obs:
# An observation node of type GEN_OBS
if len(obs_index_list) > 0:
active_obs_list = obs_group.getActiveList(obs_node_name)
if len(obs_index_list) > 50:
debug_print(
f"More than 50 active obs for {obs_node_name}",
LogLevel.LEVEL3,
user_config.log_level,
)

for string_index in obs_index_list:
index = int(string_index)
if len(obs_index_list) <= 50:
debug_print(
f"Active obs for {obs_node_name} index: {index}",
LogLevel.LEVEL3,
user_config.log_level,
)

active_obs_list.addActiveIndex(index)

# Setup ministep
debug_print(
Expand Down Expand Up @@ -856,3 +892,37 @@ def write_qc_parameter(
grid.write_grdecl(scaling_kw, file)
# Increase parameter number to define unique parameter name
cls.scaling_param_number = cls.scaling_param_number + 1


def get_obs_from_ert(ert, expand_gen_obs_max_size):
facade = LibresFacade(ert)
ert_obs = facade.get_observations()
obs_keys = []
if expand_gen_obs_max_size == 0:
obs_keys = [facade.get_observation_key(nr) for nr, _ in enumerate(ert_obs)]
return obs_keys

for nr, _ in enumerate(ert_obs):
key = facade.get_observation_key(nr)
impl_type = facade.get_impl_type_name_for_obs_key(key)
print(f"obs key: {key} obs type: {impl_type}")
# keylist_gen_obs =ert_obs.getTypedKeylist(obs_type)
if impl_type == "GEN_OBS":
obs_vector = ert_obs[key]
print(f"obs_vector: {obs_vector}")
timestep = obs_vector.activeStep()
obs_node = obs_vector.getNode(timestep)
print(f"obs_node: {obs_node}")
data_size = obs_node.getSize()
print(f"data_size: {data_size}")
if data_size <= expand_gen_obs_max_size:
obs_key_with_index_list = [
key + ":" + str(item) for item in range(data_size)
]
obs_keys.extend(obs_key_with_index_list)
else:
obs_keys.append(key)
else:
obs_keys.append(key)
print(f"obs keys : {obs_keys}")
return obs_keys
29 changes: 29 additions & 0 deletions semeio/workflows/localisation/localisation_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,19 @@ def validate_surface_scale(cls, value):
)


class MaxGenObsSize(PydanticBaseModel):
"""
max_gen_obs_size: Integer >=0. Default: 0
If it is > 0, it defines that all GEN_OBS observations is
expanded into the form nodename:index and that the
user must specify GEN_OBS type observations in
the form nodename:index or nodename:* if
all observations for a GEN_OBS node is used.
"""

max_gen_obs_size: Optional[conint(ge=0)] = 0


class LocalisationConfig(BaseModel):
"""
observations: A list of observations from ERT in format nodename
Expand All @@ -339,13 +352,23 @@ class LocalisationConfig(BaseModel):
log_level: Integer defining how much log output to write to screen
write_scaling_factors: Turn on writing calculated scaling parameters to file.
Possible values: True/False. Default: False
max_gen_obs_size: Integer defining max size for a GEN_OBS node to
be expanded in the form nodename:index.
If the observation node of type GEN_OBS has more
observations than this number, it can only be specified with
node name which then represents the whole set of
observations for the node.
Possible values: Integers >= 0
Default: 0 which means that GEN_OBS nodes are specified
with node name only.
"""

observations: List[str]
parameters: List[str]
correlations: List[CorrelationConfig]
log_level: Optional[conint(ge=0, le=5)] = 1
write_scaling_factors: Optional[bool] = False
max_gen_obs_size: Optional[conint(ge=0)] = 0

@validator("log_level")
def validate_log_level(cls, level):
Expand Down Expand Up @@ -378,3 +401,9 @@ def _check_specification(items_to_add, items_to_remove, valid_items):
added_items = added_items.difference(removed_items)
added_items = list(added_items)
return sorted(added_items)


def get_max_gen_obs_size_for_expansion(config_dict):
tmp_config = MaxGenObsSize(**config_dict)
value = tmp_config.max_gen_obs_size
return value
55 changes: 55 additions & 0 deletions tests/jobs/localisation/test_configs/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@


ERT_OBS = ["OBS1", "OBS2", "OBS11", "OBS22", "OBS12", "OBS13", "OBS14", "OBS3"]
ERT_GEN_OBS = [
"GENOBSA:0",
"GENOBSA:1",
"GENOBSA:2",
"GENOBSB:0",
"GENOBSB:1",
"GENOBSC:0",
]

ERT_PARAM = [
"PARAM_NODE1:PARAM1",
"PARAM_NODE1:PARAM2",
Expand Down Expand Up @@ -141,6 +150,52 @@ def test_simple_config(param_group_add, expected):
assert sorted(conf.correlations[0].param_group.result_items) == sorted(expected)


@pytest.mark.parametrize(
"obs_group_add, obs_group_remove, expected",
[
(
"GENOBS*",
[],
[
"GENOBSA:0",
"GENOBSA:1",
"GENOBSA:2",
"GENOBSB:0",
"GENOBSB:1",
"GENOBSC:0",
],
),
(
["GENOBSB:*"],
["GENOBSB:0"],
["GENOBSB:1"],
),
(
["*"],
["*B:0"],
["GENOBSA:0", "GENOBSA:1", "GENOBSA:2", "GENOBSB:1", "GENOBSC:0"],
),
],
)
def test_gen_obs_config(obs_group_add, obs_group_remove, expected):
data = {
"log_level": 2,
"max_gen_obs_size": 10,
"correlations": [
{
"name": "some_name",
"obs_group": {
"add": obs_group_add,
"remove": obs_group_remove,
},
"param_group": {"add": ["PARAM_NODE1:*"]},
}
],
}
conf = LocalisationConfig(observations=ERT_GEN_OBS, parameters=ERT_PARAM, **data)
assert sorted(conf.correlations[0].obs_group.result_items) == sorted(expected)


@pytest.mark.parametrize(
"obs_group_add, param_group_add, param_group_remove, expected_error",
[
Expand Down
27 changes: 27 additions & 0 deletions tests/jobs/localisation/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,3 +518,30 @@ def test_localisation_field2(setup_poly_ert):
with open("local_config.yaml", "w") as fout:
yaml.dump(config, fout)
LocalisationConfigJob(ert).run("local_config.yaml")


def test_localisation_gen_obs(
setup_poly_ert,
):
res_config = ResConfig("poly.ert")
ert = EnKFMain(res_config)
config = {
"log_level": 2,
"max_gen_obs_size": 1000,
"correlations": [
{
"name": "CORR1",
"obs_group": {
"add": ["POLY_OBS:*"],
# "add": ["POLY_OBS"],
},
"param_group": {
"add": ["*"],
},
},
],
}

with open("local_config.yaml", "w", encoding="utf-8") as fout:
yaml.dump(config, fout)
LocalisationConfigJob(ert).run("local_config.yaml")

0 comments on commit f704553

Please sign in to comment.