Skip to content

Commit

Permalink
Set first obs as default value in correlation view
Browse files Browse the repository at this point in the history
  • Loading branch information
hnformentin committed Jul 13, 2022
1 parent b350831 commit db68325
Showing 1 changed file with 46 additions and 12 deletions.
58 changes: 46 additions & 12 deletions webviz_ert/controllers/response_correlation_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import plotly.graph_objects as go
import dash

from typing import List, Optional, Dict, Tuple, Any
from typing import List, Optional, Dict, Tuple, Any, Union
from copy import deepcopy
from dash import dcc, html
from dash.development.base_component import Component
Expand Down Expand Up @@ -162,8 +162,8 @@ def update_response_overview_plot(
if not (ensembles and responses and corr_param_resp["response"] in responses):
raise PreventUpdate
selected_response = corr_param_resp["response"]
_plots = []
_obs_plots: List[PlotModel] = []
plots = []
obs_plots: List[PlotModel] = []

loaded_ensembles = [
load_ensemble(parent, ensemble_id) for ensemble_id in ensembles
Expand All @@ -173,8 +173,10 @@ def update_response_overview_plot(
response = ensemble.responses[selected_response]

x_axis = response.axis

if isinstance(x_axis, pd.Index) and x_axis.empty:
continue

if x_axis is not None:
if str(x_axis[0]).isnumeric():
style = deepcopy(assets.ERTSTYLE["response-plot"]["response-index"])
Expand All @@ -185,7 +187,7 @@ def update_response_overview_plot(
style.update({"marker": {"color": ensemble_color}})
style.update({"line": {"color": ensemble_color}})

_plots += [
plots += [
PlotModel(
x_axis=x_axis,
y_axis=data_df[realization],
Expand All @@ -198,23 +200,36 @@ def update_response_overview_plot(

if response.observations:
for obs in response.observations:
_obs_plots.append(_get_observation_plots(obs.data_df()))
obs_plots.append(_get_observation_plots(obs.data_df()))

x_axis_default_observation = _get_first_observation_x(
response.observations[0].data_df()
)
if isinstance(x_axis, pd.Index):
corr_xindex[selected_response] = x_axis.get_loc(
x_axis_default_observation
)
elif isinstance(x_axis, list):
corr_xindex[selected_response] = x_axis.index(
x_axis_default_observation
)

fig = go.Figure()
for plot in _plots:
for plot in plots:
fig.add_trace(plot.repr)

_layout = assets.ERTSTYLE["figure"]["layout"].copy()
_layout.update(dict(showlegend=False))
fig.update_layout(_layout)
layout = assets.ERTSTYLE["figure"]["layout"].copy()
layout.update(dict(showlegend=False))
fig.update_layout(layout)

x_axis_label = axis_label_for_ensemble_response(
loaded_ensembles[0], selected_response
)
fig.update_layout({"xaxis": {"title": {"text": x_axis_label}}})
fig.update_layout(assets.ERTSTYLE["figure"]["layout-value-y-axis-label"])

x_index = corr_xindex.get(selected_response, 0)
default_index = 0
x_index = corr_xindex.get(selected_response, default_index)
if isinstance(x_axis, pd.Index) and not x_axis.empty:
fig.add_shape(
type="line",
Expand All @@ -226,12 +241,27 @@ def update_response_overview_plot(
line=dict(color="rgb(30, 30, 30)", dash="dash", width=3),
)
# draw observations on top
for plot in _obs_plots:
for plot in obs_plots:
fig.add_trace(plot.repr)

fig.update_layout(clickmode="event+select")
return fig

def _get_first_observation_x(obs_data: pd.DataFrame) -> Union[int, str]:
"""
:return: The first x value in the observation data, converted
to type suitable for lookup in the response vector.
"""
if type(obs_data["x_axis"][0]) == str:
return int(obs_data["x_axis"][0])
elif type(obs_data["x_axis"][0]) == pd._libs.tslibs.timestamps.Timestamp:
return str(obs_data["x_axis"][0])
else:
observation_data_type = type(obs_data["x_axis"][0])
raise ValueError(
f"obs_data type should be a str or Timestamp, but it is {observation_data_type}."
)

@app.callback(
[
Output(
Expand Down Expand Up @@ -374,7 +404,11 @@ def update_corr_index(
triggered_id = ctx.triggered[0]["prop_id"].split(".")[0]

if triggered_id == parent.uuid("parameter-selection-store-resp"):
return {response: corr_xindex.get(response, 0) for response in responses}
default_index = 0
return {
response: corr_xindex.get(response, default_index)
for response in responses
}
if click_data:
corr_xindex[corr_param_resp["response"]] = click_data["points"][0][
"pointIndex"
Expand Down

0 comments on commit db68325

Please sign in to comment.