Skip to content

Commit

Permalink
Use new obs vector interface
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed May 31, 2023
1 parent 2f8d513 commit b2a8157
Show file tree
Hide file tree
Showing 10 changed files with 18 additions and 16 deletions.
2 changes: 1 addition & 1 deletion semeio/workflows/ahm_analysis/ahmanalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def run(
if output_dir is not None:
self._reports_dir = output_dir

obs_keys = [o.getKey() for o in self.facade.get_observations()]
obs_keys = list(self.facade.get_observations().obs_vectors.keys())
key_map = _group_observations(self.facade, obs_keys, group_by)

prior_name, target_name = check_names(
Expand Down
2 changes: 1 addition & 1 deletion semeio/workflows/correlated_observations_scaling/cos.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def run(self, job_configuration):
user_config = _insert_default_group(user_config)

obs = self.facade.get_observations()
obs_keys = [o.getKey() for o in obs]
obs_keys = list(obs.obs_vectors.keys())
default_values = _get_default_values(
self.facade.get_alpha(), self.facade.get_std_cutoff()
)
Expand Down
6 changes: 3 additions & 3 deletions semeio/workflows/correlated_observations_scaling/obs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ def _active_list_from_index_list(index_list):


def _data_index_to_obs_index(obs, obs_key, data_index_list):
if obs[obs_key].getImplementationType().name != "GEN_OBS":
if obs[obs_key].observation_type.name != "GEN_OBS":
return data_index_list
if data_index_list is None:
return data_index_list

for timestep in obs[obs_key].getStepList():
node = obs[obs_key].getNode(timestep)
for timestep, node in obs[obs_key].observations.items():
node = obs[obs_key].observations[timestep]
index_map = {node.indices[nr]: nr for nr in range(len(node))}
return [index_map[index] for index in data_index_list]
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ def _update_scaling(obs, scale_factor, obs_list):
"""
for event in obs_list:
obs_vector = obs[event.key]
step_list = obs_vector.getStepList() # List of steps, 1-indexed
for step in step_list:
obs_node = obs_vector.getNode(step)
if obs_vector.getImplementationType().name == "SUMMARY_OBS":
index_list = event.index if event.index else [x - 1 for x in step_list]
index_list = (
event.index
if event.index
else [x - 1 for x in obs_vector.observations.keys()]
)
for step, obs_node in obs_vector.observations.items():
if obs_vector.observation_type.name == "SUMMARY_OBS":
if step - 1 in index_list:
obs_node.std_scaling = scale_factor
else:
Expand Down
2 changes: 1 addition & 1 deletion semeio/workflows/localisation/local_config_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def run(self, *args):
config_dict = local.read_localisation_config(args)

# Get all observations from ert instance
obs_keys = [o.getKey() for o in self.facade.get_observations()]
obs_keys = list(self.facade.get_observations().obs_vectors.keys())

ensemble_config = ert.ensembleConfig()
ert_parameters = local.get_param_from_ert(ensemble_config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class MisfitPreprocessorJob(SemeioScript):
# pylint: disable=method-hidden
def run(self, *args):
config_record = _fetch_config_record(args)
observations = [o.getKey() for o in self.facade.get_observations()]
observations = list(self.facade.get_observations().obs_vectors.keys())
config = assemble_config(config_record, observations)
if config.reports_directory:
self._reports_dir = config.reports_directory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def run(self, *args):
# pylint: disable=method-hidden
# (SemeioScript wraps this run method)

obs_keys = [o.getKey() for o in self.facade.get_observations()]
obs_keys = list(self.facade.get_observations().obs_vectors.keys())

measured_data = self.facade.get_measured_data(obs_keys, ensemble=self.ensemble)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
],
install_requires=[
"ecl",
"ert>=5.0.0rc1",
"ert>=5.0.0rc3",
"configsuite>=0.6",
"numpy",
"pandas>1.3.0",
Expand Down
2 changes: 1 addition & 1 deletion tests/jobs/spearman_correlation_job/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def raising_scaling_job(data):
@pytest.fixture(name="facade")
def fixture_facade():
facade = Mock()
facade.get_observations.return_value = []
facade.get_observations.return_value.obs_vectors = {}
return facade


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_scale_summary_obs(snake_oil_obs, index_list):
scale_observations(snake_oil_obs, 1.2345, [Config("WOPR_OP1_36", index_list)])

obs_vector = snake_oil_obs["WOPR_OP1_36"]
node = obs_vector.getNode(36)
node = obs_vector.observations[36]
assert node.std_scaling == 1.2345, f"index: {36}"


Expand Down

0 comments on commit b2a8157

Please sign in to comment.