Skip to content

Commit

Permalink
Add tests to verify active param and active obs in ministeps
Browse files Browse the repository at this point in the history
  • Loading branch information
oddvarlia committed Nov 11, 2021
1 parent b3f3c7d commit fe66e39
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 29 deletions.
3 changes: 2 additions & 1 deletion semeio/workflows/localisation/local_config_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ def run(self, *args):
ert.eclConfig().getGrid(),
)
if config.verify_active_parameters:
local.verify_ministep(
local.verify_ministep_active_param(
config.correlations,
ert.getLocalConfig(),
ert.ensembleConfig(),
ert_parameters.to_dict(),
)
local.verify_ministep_active_obs(config.correlations, ert)


DESCRIPTION = """
Expand Down
89 changes: 74 additions & 15 deletions semeio/workflows/localisation/local_script_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,6 @@ def add_ministeps(
obs_group_name = ministep_name + "_obs_group"
obs_group = ert_local_config.createObsdata(obs_group_name)

# obs_list = corr_spec.obs_group.result_items
obs_dict = Parameters.from_list(corr_spec.obs_group.result_items).to_dict()
param_dict = Parameters.from_list(corr_spec.param_group.result_items).to_dict()

Expand Down Expand Up @@ -728,8 +727,6 @@ def add_ministeps(
# only obs_node_name if all observations in GEN_OBS is active.
obs_type = EnkfObservationImplementationType.GEN_OBS
key_list_gen_obs = ert_obs.getTypedKeylist(obs_type)
# print(f"key_list_gen_obs: {key_list_gen_obs}")
# print(f"obs_dict: {obs_dict}")
for obs_node_name, obs_index_list in obs_dict.items():
obs_group.addNode(obs_node_name)
debug_print(
Expand Down Expand Up @@ -795,7 +792,7 @@ def get_corr_group_spec(correlations_spec_list, name):
return corr_spec


def verify_ministep(
def verify_ministep_active_param(
corr_spec_list, ert_local_config, ert_ensemble_config, ert_param_dict
):
"""
Expand All @@ -804,7 +801,7 @@ def verify_ministep(
Reports mismatch if found and silent if OK.
Used for test purpose.
"""
print("\nVerify ministep setup:")
print("\nVerify ministep setup for active parameters:")
updatestep = ert_local_config.getUpdatestep()
for ministep in updatestep:
print(f"Ministep: {ministep.name()}")
Expand Down Expand Up @@ -853,8 +850,9 @@ def verify_ministep(
active_index_list = active_list_obj.getActiveIndexList()
spec_index_list.sort()
active_index_list.sort()
print(f"Spec_index_list: {spec_index_list}")
print(f"Ministep index_list: {active_index_list}")
print(f" Param node: {node_name}")
print(f" Active param indices (user specified): {spec_index_list}")
print(f" Active param indices (ministep): {active_index_list}")
if len(spec_index_list) != len(active_index_list):
raise ValueError(
f"For ministep: {ministep.name()} the number of "
Expand All @@ -869,14 +867,75 @@ def verify_ministep(
err = True
if err:
raise ValueError(
f" In ministep: {ministep.name()} there is a "
"mismatch between specified "
"active parameters and active parameters in the ministep.\n"
f" For ministep: {ministep.name()} and "
f"parameter node: {node_name}:\n"
"Mismatch between specified active parameters "
f"and active parameters in the ministep.\n"
f"Specified: {spec_index_list}\n"
f"In ministep: {active_index_list}\n"
)


def verify_ministep_active_obs(corr_spec_list, ert):
# pylint: disable=R1702
"""
Script to verify that the local config matches the specified user config for
active observations.
Reports mismatch if found and silent if OK.
Used for test purpose.
"""
print("\nVerify ministep setup for active observations:")
facade = LibresFacade(ert)
ert_obs = facade.get_observations()
ert_local_config = ert.getLocalConfig()

updatestep = ert_local_config.getUpdatestep()
for ministep in updatestep:
print(f"Ministep: {ministep.name()}")
# User specification
corr_spec = get_corr_group_spec(corr_spec_list, ministep.name())
obs_dict = Parameters.from_list(corr_spec.obs_group.result_items).to_dict()

# Data from local config, only one obs group in a ministep here.
local_obs_data = ministep.getLocalObsData()
for obs_node in local_obs_data:
key = obs_node.key()
impl_type = facade.get_impl_type_name_for_obs_key(key)
if impl_type == "GEN_OBS":
active_list_obj = obs_node.getActiveList()
if active_list_obj.getMode() == ActiveMode.PARTLY_ACTIVE:
obs_vector = ert_obs[key]
# Always 1 timestep for a GEN_OBS
timestep = obs_vector.activeStep()
genobs_node = obs_vector.getNode(timestep)
data_size = genobs_node.getSize()
active_list_obj = obs_node.getActiveList()
active_index_list = active_list_obj.getActiveIndexList()
active_index_list.sort()

# From user specification
str_list = obs_dict[key]
spec_index_list = [int(str_list[i]) for i in range(len(str_list))]
spec_index_list.sort()
err = False
for nr, index in enumerate(active_index_list):
if index != spec_index_list[nr]:
err = True
if err:
raise ValueError(
f" For ministep: {ministep.name()} and "
f"observation node: {key}:\n"
"Mismatch between specified active observations and "
"active observations defined in the ministep.\n"
f"Specified: {spec_index_list}\n"
f"In ministep: {active_index_list}\n"
)
print(f" Obs node: {key}")
print(f" Full size of obs node: {data_size}")
print(f" Active obs indices (user specified): {spec_index_list}")
print(f" Active obs indices (ministep): {active_index_list}")


def clear_correlations(ert):
local_config = ert.getLocalConfig()
local_config.clear()
Expand Down Expand Up @@ -1002,16 +1061,16 @@ def get_obs_from_ert(ert, expand_gen_obs_max_size):
for nr, _ in enumerate(ert_obs):
key = facade.get_observation_key(nr)
impl_type = facade.get_impl_type_name_for_obs_key(key)
print(f"obs key: {key} obs type: {impl_type}")
# print(f"obs key: {key} obs type: {impl_type}")
# keylist_gen_obs =ert_obs.getTypedKeylist(obs_type)
if impl_type == "GEN_OBS":
obs_vector = ert_obs[key]
print(f"obs_vector: {obs_vector}")
# print(f"obs_vector: {obs_vector}")
timestep = obs_vector.activeStep()
obs_node = obs_vector.getNode(timestep)
print(f"obs_node: {obs_node}")
# print(f"obs_node: {obs_node}")
data_size = obs_node.getSize()
print(f"data_size: {data_size}")
# print(f"data_size: {data_size}")
if data_size <= expand_gen_obs_max_size:
obs_key_with_index_list = [
key + ":" + str(item) for item in range(data_size)
Expand All @@ -1021,5 +1080,5 @@ def get_obs_from_ert(ert, expand_gen_obs_max_size):
obs_keys.append(key)
else:
obs_keys.append(key)
print(f"obs keys : {obs_keys}")
# print(f"obs keys : {obs_keys}")
return obs_keys
2 changes: 1 addition & 1 deletion semeio/workflows/localisation/localisation_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ class LocalisationConfig(BaseModel):
correlations: List[CorrelationConfig]
log_level: Optional[conint(ge=0, le=5)] = 1
write_scaling_factors: Optional[bool] = False
verify_active_parameters: Optional[bool] = False
verify_active: Optional[bool] = False
max_gen_obs_size: Optional[conint(ge=0)] = 0

@validator("log_level")
Expand Down
108 changes: 96 additions & 12 deletions tests/jobs/localisation/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def test_localisation(setup_ert, obs_group_add, param_group_add, expected):
ert = EnKFMain(setup_ert)
config = {
"log_level": 4,
"verify_active": True,
"correlations": [
{
"name": "CORR1",
Expand Down Expand Up @@ -95,11 +96,12 @@ def test_localisation_gen_kw(setup_ert):
ert = EnKFMain(setup_ert, verbose=True)
config = {
"log_level": 4,
"verify_active_parameters": True,
"verify_active": True,
"max_gen_obs_size": 1000,
"correlations": [
{
"name": "CORR12",
"obs_group": {"add": "*"},
"obs_group": {"add": ["WPR_DIFF_1:0", "WPR_DIFF_1:3"]},
"param_group": {
"add": [
"SNAKE_OIL_PARAM:OP1_PERSISTENCE",
Expand All @@ -109,14 +111,17 @@ def test_localisation_gen_kw(setup_ert):
},
{
"name": "CORR3",
"obs_group": {"add": "*"},
"obs_group": {"add": "WPR_DIFF_1:2"},
"param_group": {
"add": "SNAKE_OIL_PARAM:OP1_DIVERGENCE_SCALE",
},
},
{
"name": "CORR4",
"obs_group": {"add": "*"},
"obs_group": {
"add": "*",
"remove": ["WPR_DIFF_1:1", "WPR_DIFF_1:0"],
},
"param_group": {
"add": "SNAKE_OIL_PARAM:OP1_OFFSET",
},
Expand Down Expand Up @@ -183,6 +188,7 @@ def test_localisation_gen_param(
ert = EnKFMain(res_config)
config = {
"log_level": 2,
"verify_active": True,
"correlations": [
{
"name": "CORR1",
Expand Down Expand Up @@ -241,6 +247,7 @@ def test_localisation_surf(
ert = EnKFMain(res_config)
config = {
"log_level": 3,
"verify_active": True,
"correlations": [
{
"name": "CORR1",
Expand Down Expand Up @@ -296,7 +303,6 @@ def test_localisation_field1(
values = np.zeros((nx, ny, nz), dtype=np.float32)
property_field.values = values + 0.1 * n
filename = pname + "_" + str(n) + ".roff"
print(f"Write file: {filename}")
property_field.to_file(filename, fformat="roff", name=pname)

fout.write(
Expand All @@ -310,6 +316,7 @@ def test_localisation_field1(
config = {
"log_level": 3,
"write_scaling_factors": True,
"verify_active": True,
"correlations": [
{
"name": "CORR1",
Expand Down Expand Up @@ -397,7 +404,6 @@ def create_box_grid_with_inactive_and_active_cells(
if has_inactive_values:
grid.inactivate_outside(polygon, force_close=True)

print(f" Write file: {output_grid_file}")
grid.to_file(output_grid_file, fformat="egrid")
return grid

Expand Down Expand Up @@ -438,7 +444,6 @@ def create_region_parameter(filename, grid):
else:
values[i, j, k] = 4
region_param.values = values
print(f"Write file: {filename}")
region_param.to_file(filename, fformat="grdecl", name=region_param_name)


Expand Down Expand Up @@ -467,9 +472,8 @@ def create_field_and_scaling_param_and_update_poly_ert(
values = np.zeros((nx, ny, nz), dtype=np.float32)
property_field.values = values + 0.1 * n
filename = property_name + "_" + str(n) + ".roff"
print(f"Write file: {filename}")
property_field.to_file(filename, fformat="roff", name=property_name)
print(f"Write file: {scaling_filename}\n")

scaling_field.to_file(scaling_filename, fformat="grdecl", name=scaling_name)

fout.write(
Expand Down Expand Up @@ -501,6 +505,7 @@ def test_localisation_field2(setup_poly_ert):
config = {
"log_level": 3,
"write_scaling_factors": True,
"verify_active": True,
"correlations": [
{
"name": "CORR_GAUSSIAN",
Expand Down Expand Up @@ -598,20 +603,99 @@ def test_localisation_gen_obs(
config = {
"log_level": 2,
"max_gen_obs_size": 1000,
"verify_active": True,
"correlations": [
{
"name": "CORR1",
"obs_group": {
"add": ["POLY_OBS:*"],
# "add": ["POLY_OBS"],
},
"param_group": {
"add": ["*"],
},
},
],
}
with open("local_config_gen_obs.yaml", "w", encoding="utf-8") as fout:
yaml.dump(config, fout)
LocalisationConfigJob(ert).run("local_config_gen_obs.yaml")
expected = {}
expected["CORR1"] = [0, 1, 2, 3, 4]

with open("local_config.yaml", "w", encoding="utf-8") as fout:
ert_local_config = ert.getLocalConfig()
updatestep = ert_local_config.getUpdatestep()
active_indices = {}
for ministep in updatestep:
local_obs_data = ministep.getLocalObsData()
for obs_node in local_obs_data:
active_list_obj = obs_node.getActiveList()
active_indices_list = active_list_obj.getActiveIndexList()
active_indices_list.sort()
active_indices[ministep.name()] = active_indices_list

assert active_indices == expected


@pytest.mark.parametrize(
"obs_group_add1, obs_group_add2, expected",
[
(
["POLY_OBS:0", "POLY_OBS:1", "POLY_OBS:2"],
["POLY_OBS:3", "POLY_OBS:4"],
{"CORR1": [0, 1, 2], "CORR2": [3, 4]},
),
(
["POLY_OBS:4"],
["POLY_OBS:3"],
{
"CORR1": [4],
"CORR2": [3],
},
),
],
)
def test_localisation_gen_obs2(
setup_poly_ert, obs_group_add1, obs_group_add2, expected
):
res_config = ResConfig("poly.ert")
ert = EnKFMain(res_config)
config = {
"log_level": 2,
"max_gen_obs_size": 1000,
"verify_active": True,
"correlations": [
{
"name": "CORR1",
"obs_group": {
"add": obs_group_add1,
},
"param_group": {
"add": ["*"],
},
},
{
"name": "CORR2",
"obs_group": {
"add": obs_group_add2,
},
"param_group": {
"add": ["*"],
},
},
],
}
with open("local_config_gen_obs2.yaml", "w", encoding="utf-8") as fout:
yaml.dump(config, fout)
LocalisationConfigJob(ert).run("local_config.yaml")
LocalisationConfigJob(ert).run("local_config_gen_obs2.yaml")
ert_local_config = ert.getLocalConfig()
updatestep = ert_local_config.getUpdatestep()
active_indices = {}
for ministep in updatestep:
local_obs_data = ministep.getLocalObsData()
for obs_node in local_obs_data:
active_list_obj = obs_node.getActiveList()
active_indices_list = active_list_obj.getActiveIndexList()
active_indices_list.sort()
active_indices[ministep.name()] = active_indices_list

assert active_indices == expected

0 comments on commit fe66e39

Please sign in to comment.