diff --git a/semeio/workflows/localisation/local_config_script.py b/semeio/workflows/localisation/local_config_script.py index b89416520..d8d473bef 100644 --- a/semeio/workflows/localisation/local_config_script.py +++ b/semeio/workflows/localisation/local_config_script.py @@ -1,26 +1,33 @@ -from ert_shared.libres_facade import LibresFacade +# from ert_shared.libres_facade import LibresFacade from ert_shared.plugins.plugin_manager import hook_implementation import semeio.workflows.localisation.local_script_lib as local from semeio.communication import SemeioScript -from semeio.workflows.localisation.localisation_config import LocalisationConfig +from semeio.workflows.localisation.localisation_config import ( + LocalisationConfig, + get_max_gen_obs_size_for_expansion, +) class LocalisationConfigJob(SemeioScript): def run(self, *args): ert = self.ert() - facade = LibresFacade(self.ert()) + # facade = LibresFacade(self.ert()) # Clear all correlations local.clear_correlations(ert) # Read yml file with specifications config_dict = local.read_localisation_config(args) + # print(f"config_dict:\n {config_dict}") + # Get all observations from ert instance - obs_keys = [ - facade.get_observation_key(nr) - for nr, _ in enumerate(facade.get_observations()) - ] + # obs_keys = [ + # facade.get_observation_key(nr) + # for nr, _ in enumerate(facade.get_observations()) + # ] + expand_gen_obs_max_size = get_max_gen_obs_size_for_expansion(config_dict) + obs_keys = local.get_obs_from_ert(ert, expand_gen_obs_max_size) ert_parameters = local.get_param_from_ert(ert.ensembleConfig()) @@ -35,6 +42,7 @@ def run(self, *args): ert_parameters.to_dict(), ert.getLocalConfig(), ert.ensembleConfig(), + ert.getObservations(), ert.eclConfig().getGrid(), ) diff --git a/semeio/workflows/localisation/local_script_lib.py b/semeio/workflows/localisation/local_script_lib.py index f4e02265d..9ac5865c0 100644 --- a/semeio/workflows/localisation/local_script_lib.py +++ b/semeio/workflows/localisation/local_script_lib.py @@ -15,13 +15,18 @@ from ecl.eclfile import Ecl3DKW from ecl.ecl_type import EclDataType from ecl.grid.ecl_grid import EclGrid + from res.enkf.enums.ert_impl_type_enum import ErtImplType from res.enkf.enums.enkf_var_type_enum import EnkfVarType +from res.enkf import EnkfObservationImplementationType + +# from res.enkf.enums.active_mode_enum import ActiveMode from semeio.workflows.localisation.localisation_debug_settings import ( LogLevel, debug_print, ) +from ert_shared.libres_facade import LibresFacade @dataclass @@ -524,11 +529,12 @@ def add_ministeps( ert_param_dict, ert_local_config, ert_ensemble_config, + ert_obs, grid_for_field, ): # pylint: disable-msg=too-many-branches # pylint: disable-msg=R0915 - + # pylint: disable-msg=R1702 debug_print("Add all ministeps:", LogLevel.LEVEL1, user_config.log_level) ScalingValues.initialize() # Read all region files used in correlation groups, @@ -539,6 +545,7 @@ def add_ministeps( ) for count, corr_spec in enumerate(user_config.correlations): + ministep_name = corr_spec.name ministep = ert_local_config.createMinistep(ministep_name) debug_print( @@ -551,7 +558,8 @@ def add_ministeps( obs_group_name = ministep_name + "_obs_group" obs_group = ert_local_config.createObsdata(obs_group_name) - obs_list = corr_spec.obs_group.result_items + # obs_list = corr_spec.obs_group.result_items + obs_dict = Parameters.from_list(corr_spec.obs_group.result_items).to_dict() param_dict = Parameters.from_list(corr_spec.param_group.result_items).to_dict() # Setup model parameter group @@ -714,12 +722,40 @@ def add_ministeps( user_config.log_level, ) - # Setup observation group - for obs_name in obs_list: + # Setup observation group. For GEN_OBS type + # the observation specification can be of the form obs_node_name:index + # if individual observations from a GEN_OBS node is chosen or + # only obs_node_name if all observations in GEN_OBS is active. + obs_type = EnkfObservationImplementationType.GEN_OBS + key_list_gen_obs = ert_obs.getTypedKeylist(obs_type) + # print(f"key_list_gen_obs: {key_list_gen_obs}") + # print(f"obs_dict: {obs_dict}") + for obs_node_name, obs_index_list in obs_dict.items(): + obs_group.addNode(obs_node_name) debug_print( - f"Add obs node: {obs_name}", LogLevel.LEVEL2, user_config.log_level + f"Add obs node: {obs_node_name}", LogLevel.LEVEL2, user_config.log_level ) - obs_group.addNode(obs_name) + if obs_node_name in key_list_gen_obs: + # An observation node of type GEN_OBS + if len(obs_index_list) > 0: + active_obs_list = obs_group.getActiveList(obs_node_name) + if len(obs_index_list) > 50: + debug_print( + f"More than 50 active obs for {obs_node_name}", + LogLevel.LEVEL3, + user_config.log_level, + ) + + for string_index in obs_index_list: + index = int(string_index) + if len(obs_index_list) <= 50: + debug_print( + f"Active obs for {obs_node_name} index: {index}", + LogLevel.LEVEL3, + user_config.log_level, + ) + + active_obs_list.addActiveIndex(index) # Setup ministep debug_print( @@ -856,3 +892,37 @@ def write_qc_parameter( grid.write_grdecl(scaling_kw, file) # Increase parameter number to define unique parameter name cls.scaling_param_number = cls.scaling_param_number + 1 + + +def get_obs_from_ert(ert, expand_gen_obs_max_size): + facade = LibresFacade(ert) + ert_obs = facade.get_observations() + obs_keys = [] + if expand_gen_obs_max_size == 0: + obs_keys = [facade.get_observation_key(nr) for nr, _ in enumerate(ert_obs)] + return obs_keys + + for nr, _ in enumerate(ert_obs): + key = facade.get_observation_key(nr) + impl_type = facade.get_impl_type_name_for_obs_key(key) + print(f"obs key: {key} obs type: {impl_type}") + # keylist_gen_obs =ert_obs.getTypedKeylist(obs_type) + if impl_type == "GEN_OBS": + obs_vector = ert_obs[key] + print(f"obs_vector: {obs_vector}") + timestep = obs_vector.activeStep() + obs_node = obs_vector.getNode(timestep) + print(f"obs_node: {obs_node}") + data_size = obs_node.getSize() + print(f"data_size: {data_size}") + if data_size <= expand_gen_obs_max_size: + obs_key_with_index_list = [ + key + ":" + str(item) for item in range(data_size) + ] + obs_keys.extend(obs_key_with_index_list) + else: + obs_keys.append(key) + else: + obs_keys.append(key) + print(f"obs keys : {obs_keys}") + return obs_keys diff --git a/semeio/workflows/localisation/localisation_config.py b/semeio/workflows/localisation/localisation_config.py index 84d44fc6e..c80e6134e 100644 --- a/semeio/workflows/localisation/localisation_config.py +++ b/semeio/workflows/localisation/localisation_config.py @@ -327,6 +327,19 @@ def validate_surface_scale(cls, value): ) +class MaxGenObsSize(PydanticBaseModel): + """ + max_gen_obs_size: Integer >=0. Default: 0 + If it is > 0, it defines that all GEN_OBS observations is + expanded into the form nodename:index and that the + user must specify GEN_OBS type observations in + the form nodename:index or nodename:* if + all observations for a GEN_OBS node is used. + """ + + max_gen_obs_size: Optional[conint(ge=0)] = 0 + + class LocalisationConfig(BaseModel): """ observations: A list of observations from ERT in format nodename @@ -339,6 +352,15 @@ class LocalisationConfig(BaseModel): log_level: Integer defining how much log output to write to screen write_scaling_factors: Turn on writing calculated scaling parameters to file. Possible values: True/False. Default: False + max_gen_obs_size: Integer defining max size for a GEN_OBS node to + be expanded in the form nodename:index. + If the observation node of type GEN_OBS has more + observations than this number, it can only be specified with + node name which then represents the whole set of + observations for the node. + Possible values: Integers >= 0 + Default: 0 which means that GEN_OBS nodes are specified + with node name only. """ observations: List[str] @@ -346,6 +368,7 @@ class LocalisationConfig(BaseModel): correlations: List[CorrelationConfig] log_level: Optional[conint(ge=0, le=5)] = 1 write_scaling_factors: Optional[bool] = False + max_gen_obs_size: Optional[conint(ge=0)] = 0 @validator("log_level") def validate_log_level(cls, level): @@ -378,3 +401,9 @@ def _check_specification(items_to_add, items_to_remove, valid_items): added_items = added_items.difference(removed_items) added_items = list(added_items) return sorted(added_items) + + +def get_max_gen_obs_size_for_expansion(config_dict): + tmp_config = MaxGenObsSize(**config_dict) + value = tmp_config.max_gen_obs_size + return value diff --git a/tests/jobs/localisation/test_configs/test_config.py b/tests/jobs/localisation/test_configs/test_config.py index a3cb448dd..6fd053797 100644 --- a/tests/jobs/localisation/test_configs/test_config.py +++ b/tests/jobs/localisation/test_configs/test_config.py @@ -12,6 +12,15 @@ ERT_OBS = ["OBS1", "OBS2", "OBS11", "OBS22", "OBS12", "OBS13", "OBS14", "OBS3"] +ERT_GEN_OBS = [ + "GENOBSA:0", + "GENOBSA:1", + "GENOBSA:2", + "GENOBSB:0", + "GENOBSB:1", + "GENOBSC:0", +] + ERT_PARAM = [ "PARAM_NODE1:PARAM1", "PARAM_NODE1:PARAM2", @@ -141,6 +150,52 @@ def test_simple_config(param_group_add, expected): assert sorted(conf.correlations[0].param_group.result_items) == sorted(expected) +@pytest.mark.parametrize( + "obs_group_add, obs_group_remove, expected", + [ + ( + "GENOBS*", + [], + [ + "GENOBSA:0", + "GENOBSA:1", + "GENOBSA:2", + "GENOBSB:0", + "GENOBSB:1", + "GENOBSC:0", + ], + ), + ( + ["GENOBSB:*"], + ["GENOBSB:0"], + ["GENOBSB:1"], + ), + ( + ["*"], + ["*B:0"], + ["GENOBSA:0", "GENOBSA:1", "GENOBSA:2", "GENOBSB:1", "GENOBSC:0"], + ), + ], +) +def test_gen_obs_config(obs_group_add, obs_group_remove, expected): + data = { + "log_level": 2, + "max_gen_obs_size": 10, + "correlations": [ + { + "name": "some_name", + "obs_group": { + "add": obs_group_add, + "remove": obs_group_remove, + }, + "param_group": {"add": ["PARAM_NODE1:*"]}, + } + ], + } + conf = LocalisationConfig(observations=ERT_GEN_OBS, parameters=ERT_PARAM, **data) + assert sorted(conf.correlations[0].obs_group.result_items) == sorted(expected) + + @pytest.mark.parametrize( "obs_group_add, param_group_add, param_group_remove, expected_error", [ diff --git a/tests/jobs/localisation/test_integration.py b/tests/jobs/localisation/test_integration.py index 340b654a7..5717cf359 100644 --- a/tests/jobs/localisation/test_integration.py +++ b/tests/jobs/localisation/test_integration.py @@ -518,3 +518,30 @@ def test_localisation_field2(setup_poly_ert): with open("local_config.yaml", "w") as fout: yaml.dump(config, fout) LocalisationConfigJob(ert).run("local_config.yaml") + + +def test_localisation_gen_obs( + setup_poly_ert, +): + res_config = ResConfig("poly.ert") + ert = EnKFMain(res_config) + config = { + "log_level": 2, + "max_gen_obs_size": 1000, + "correlations": [ + { + "name": "CORR1", + "obs_group": { + "add": ["POLY_OBS:*"], + # "add": ["POLY_OBS"], + }, + "param_group": { + "add": ["*"], + }, + }, + ], + } + + with open("local_config.yaml", "w", encoding="utf-8") as fout: + yaml.dump(config, fout) + LocalisationConfigJob(ert).run("local_config.yaml")