-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prescriptor select #10
Changes from all commits
202f993
5e047cd
15c351c
4b865c5
ea04f32
ed24487
7e5ee17
673b368
bb14966
2e308e6
9300b3a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,4 +10,4 @@ suggestion-mode=yes | |
|
||
good-names=X,F,X0 | ||
|
||
fail-under=9.6 | ||
fail-under=9.8 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,22 +7,21 @@ | |
|
||
from app.components.intro import IntroComponent | ||
from app.components.context import ContextComponent | ||
from app.components.filter import FilterComponent | ||
from app.components.outcome import OutcomeComponent | ||
from app.components.parallel import ParallelComponent | ||
from app.components.link import LinkComponent | ||
from app.components.references import ReferencesComponent | ||
from app.utils import EvolutionHandler | ||
|
||
evolution_handler = EvolutionHandler() | ||
metrics = evolution_handler.outcomes.keys() | ||
# The candidates are sorted by rank then distance so the 'best' ones are the first 10 | ||
sample_idxs = list(range(10)) | ||
|
||
intro_component = IntroComponent() | ||
context_component = ContextComponent() | ||
parallel_component = ParallelComponent(evolution_handler.load_initial_metrics_df(), | ||
sample_idxs, | ||
evolution_handler.outcomes) | ||
outcome_component = OutcomeComponent(evolution_handler, sample_idxs) | ||
filter_component = FilterComponent(metrics) | ||
outcome_component = OutcomeComponent(evolution_handler) | ||
link_component = LinkComponent(sample_idxs) | ||
references_component = ReferencesComponent() | ||
|
||
|
@@ -32,7 +31,7 @@ | |
app.title = "Climate Change Decision Making" | ||
|
||
context_component.register_callbacks(app) | ||
parallel_component.register_callbacks(app) | ||
filter_component.register_callbacks(app) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Replaced parallel coordinates component with new filter component |
||
outcome_component.register_callbacks(app) | ||
link_component.register_callbacks(app) | ||
|
||
|
@@ -41,7 +40,7 @@ | |
children=[ | ||
intro_component.create_intro_div(), | ||
context_component.create_context_div(), | ||
parallel_component.create_parallel_div(), | ||
filter_component.create_filter_div(), | ||
outcome_component.create_outcomes_div(), | ||
link_component.create_link_div(), | ||
references_component.create_references_div() | ||
|
@@ -50,4 +49,4 @@ | |
|
||
# Run the app | ||
if __name__ == '__main__': | ||
app.run_server(host='0.0.0.0', debug=False, port=4057, use_reloader=False, threaded=True) | ||
app.run_server(host='0.0.0.0', debug=False, port=4057, use_reloader=True, threaded=True) |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Created some classes we can re-use between components |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
""" | ||
Hard-coded combinations of classes using bootstrap akin to custom css classes. | ||
""" | ||
DESC_TEXT = "mb-5 w-75 text-center" | ||
|
||
JUMBOTRON = "p-3 bg-white rounded-5 mx-auto w-75 mb-3" | ||
|
||
CONTAINER = "py-3 d-flex flex-column h-100 align-items-center" | ||
|
||
HEADER = "text-center mb-5" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
""" | ||
Component in charge of filtering out prescriptors by metric. | ||
""" | ||
from dash import html, dcc, Input, Output, State | ||
import dash_bootstrap_components as dbc | ||
import pandas as pd | ||
import plotly.express as px | ||
import plotly.graph_objects as go | ||
|
||
from app.classes import JUMBOTRON, CONTAINER, DESC_TEXT, HEADER | ||
from app.utils import filter_metrics_json | ||
|
||
|
||
class FilterComponent: | ||
""" | ||
Component in charge of filtering out prescriptors by metric specific to each context. | ||
The component stores the metrics to filter with as long as their corresponding HTML ids. | ||
It also keeps track of the parameters that need to be updated for each slider. | ||
""" | ||
def __init__(self, metrics: list[str]): | ||
self.metrics = list(metrics) | ||
self.metric_ids = [metric.replace(" ", "-").replace(".", "_") for metric in self.metrics] | ||
self.updated_params = ["min", "max", "value", "marks"] | ||
|
||
def create_metric_sliders(self): | ||
""" | ||
Creates initial metric sliders and lines them up with their labels. | ||
""" | ||
sliders = [] | ||
for metric in self.metrics: | ||
col_id = metric.replace(" ", "-").replace(".", "_") | ||
slider = dcc.RangeSlider( | ||
id=f"{col_id}-slider", | ||
min=0, | ||
max=1, | ||
value=[0, 1], | ||
marks={0: f"{0:.2f}", 1: f"{1:.2f}"}, | ||
tooltip={"placement": "bottom", "always_visible": True}, | ||
allowCross=False | ||
) | ||
sliders.append(slider) | ||
|
||
div = html.Div( | ||
children=[ | ||
html.Div( | ||
className="d-flex flex-row mb-2", | ||
children=[ | ||
html.Label(self.metrics[i], className="w-25"), # w-25 and flex-grow-1 ensures they line up | ||
html.Div(sliders[i], className="flex-grow-1") | ||
] | ||
) | ||
for i in range(len(self.metrics)) | ||
] | ||
) | ||
return div | ||
|
||
def plot_parallel_coordinates_line(self, | ||
metrics_json: dict[str, list], | ||
metric_ranges: list[tuple[float, float]]) -> go.Figure: | ||
""" | ||
Plots a parallel coordinates plot of the prescriptor metrics. | ||
Starts by plotting "other" so that it's the bottom of the z axis. | ||
Then plots selected candidates in color. | ||
Finally plots the baseline on top. | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved parallel coordinates code into filter.py |
||
fig = go.Figure() | ||
|
||
normalized_df = filter_metrics_json(metrics_json, metric_ranges, normalize=True) | ||
|
||
cand_idxs = list(normalized_df.index)[:-1] # Leave out the baseline | ||
n_special_cands = min(10, len(cand_idxs)) | ||
|
||
showlegend = True | ||
# If "other" is in the cand_idxs, plot all other candidates in lightgray | ||
for cand_idx in cand_idxs[n_special_cands:]: | ||
cand_metrics = normalized_df.loc[cand_idx].values | ||
fig.add_trace(go.Scatter( | ||
x=normalized_df.columns, | ||
y=cand_metrics, | ||
mode='lines', | ||
legendgroup="other", | ||
name="other", | ||
line=dict(color="lightgray"), | ||
showlegend=showlegend | ||
)) | ||
showlegend = False | ||
|
||
# Plot selected candidates besides baseline so it can be on top | ||
for color_idx, cand_idx in enumerate(cand_idxs[:n_special_cands]): | ||
cand_metrics = normalized_df.loc[cand_idx].values | ||
fig.add_trace(go.Scatter( | ||
x=normalized_df.columns, | ||
y=cand_metrics, | ||
mode='lines', | ||
name=str(cand_idx), | ||
line=dict(color=px.colors.qualitative.Plotly[color_idx]) | ||
)) | ||
|
||
baseline_metrics = normalized_df.iloc[-1] | ||
fig.add_trace(go.Scatter( | ||
x=normalized_df.columns, | ||
y=baseline_metrics.values, | ||
mode='lines', | ||
name="baseline", | ||
line=dict(color="black") | ||
)) | ||
|
||
for i in range(len(normalized_df.columns)): | ||
fig.add_vline(x=i, line_color="black") | ||
|
||
full_metrics_df = pd.DataFrame(metrics_json) | ||
normalized_full = (full_metrics_df - full_metrics_df.mean()) / (full_metrics_df.std() + 1e-10) | ||
fig.update_layout( | ||
yaxis_range=[normalized_full.min().min(), normalized_full.max().max()], | ||
title={ | ||
'text': "Normalized Policy Metrics", | ||
'x': 0.5, # Center the title | ||
'xanchor': 'center', # Anchor it at the center | ||
'yanchor': 'top' # Optionally keep it anchored to the top | ||
} | ||
) | ||
|
||
return fig | ||
|
||
def create_filter_div(self): | ||
""" | ||
Creates div showing sliders to choose the range of metric values we want the prescriptors to have. | ||
TODO: Currently the slider tooltips show even while loading which is a bit of an eyesore. | ||
""" | ||
div = html.Div( | ||
className=JUMBOTRON, | ||
children=[ | ||
dbc.Container( | ||
fluid=True, | ||
className=CONTAINER, | ||
children=[ | ||
html.H2("Filter Policies by Desired Behavior", className=HEADER), | ||
html.P("One hundred AI models are trained to create energy policies that make different trade \ | ||
offs in metrics. Use the sliders below to filter the AI generated policies \ | ||
that produce desired behavior. See the results of the filtering in the below sections.", | ||
className=DESC_TEXT), | ||
html.Div( | ||
dcc.Loading( | ||
type="circle", | ||
target_components={"metrics-store": "*"}, | ||
children=[ | ||
self.create_metric_sliders(), | ||
dcc.Store(id="metrics-store") | ||
], | ||
), | ||
className="w-100 mb-5" | ||
), | ||
html.Div( | ||
dbc.Accordion( | ||
dbc.AccordionItem(dcc.Graph(id="parcoords-figure"), title="View Parallel Coordinates"), | ||
start_collapsed=True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hide parallel coordinates in accordion dropdown |
||
), | ||
className="w-100" | ||
) | ||
] | ||
) | ||
] | ||
) | ||
return div | ||
|
||
def register_callbacks(self, app): | ||
""" | ||
Registers callbacks related to the filter sliders. | ||
""" | ||
@app.callback( | ||
[Output(f"{metric_id}-slider", param) for metric_id in self.metric_ids for param in self.updated_params], | ||
Input("metrics-store", "data") | ||
) | ||
def update_filter_sliders(metrics_jsonl: list[dict[str, list]]) -> list: | ||
""" | ||
Update the filter slider min/max/value/marks based on the incoming metrics data. The output of this function | ||
is a list of the updated parameters for each slider concatenated. | ||
""" | ||
metrics_df = pd.DataFrame(metrics_jsonl) | ||
total_output = [] | ||
for metric in self.metrics: | ||
metric_output = [] | ||
min_val = metrics_df[metric].min() | ||
max_val = metrics_df[metric].max() | ||
# We need to round down for the min value and round up for the max value | ||
min_val_rounded = min_val // 0.01 / 100 | ||
max_val_rounded = max_val + 0.01 | ||
metric_output = [ | ||
min_val_rounded, | ||
max_val_rounded, | ||
[min_val_rounded, max_val_rounded], | ||
{min_val_rounded: f"{min_val_rounded:.2f}", max_val_rounded: f"{max_val_rounded:.2f}"} | ||
] | ||
total_output.extend(metric_output) | ||
|
||
return total_output | ||
|
||
@app.callback( | ||
Output("parcoords-figure", "figure"), | ||
State("metrics-store", "data"), | ||
[Input(f"{metric_id}-slider", "value") for metric_id in self.metric_ids] | ||
) | ||
def filter_parcoords_figure(metrics_json: dict[str, list], *metric_ranges) -> go.Figure: | ||
""" | ||
Filters parallel coordinates figure based on the metric ranges from the sliders. | ||
""" | ||
return self.plot_parallel_coordinates_line(metrics_json, metric_ranges) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated readme to reflect new project structure