From 202f99386968b1d1894a42fa3cb0c11b959f53ca Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Thu, 5 Sep 2024 16:59:04 -0700 Subject: [PATCH 01/11] Started working on filter component. Need to fix colors, baseline, and others --- app/app.py | 8 ++++-- app/components/filter.py | 52 +++++++++++++++++++++++++++++++++ app/components/outcome.py | 60 +++++++++++++++++++-------------------- app/utils.py | 15 ++++++++++ 4 files changed, 103 insertions(+), 32 deletions(-) create mode 100644 app/components/filter.py diff --git a/app/app.py b/app/app.py index 5228945..ceef664 100644 --- a/app/app.py +++ b/app/app.py @@ -7,8 +7,9 @@ from app.components.intro import IntroComponent from app.components.context import ContextComponent -from app.components.outcome import OutcomeComponent +from app.components.filter import FilterComponent from app.components.parallel import ParallelComponent +from app.components.outcome import OutcomeComponent from app.components.link import LinkComponent from app.components.references import ReferencesComponent from app.utils import EvolutionHandler @@ -19,6 +20,7 @@ intro_component = IntroComponent() context_component = ContextComponent() +filter_component = FilterComponent() parallel_component = ParallelComponent(evolution_handler.load_initial_metrics_df(), sample_idxs, evolution_handler.outcomes) @@ -32,6 +34,7 @@ app.title = "Climate Change Decision Making" context_component.register_callbacks(app) +filter_component.register_callbacks(app) parallel_component.register_callbacks(app) outcome_component.register_callbacks(app) link_component.register_callbacks(app) @@ -41,6 +44,7 @@ children=[ intro_component.create_intro_div(), context_component.create_context_div(), + filter_component.create_filter_div(), parallel_component.create_parallel_div(), outcome_component.create_outcomes_div(), link_component.create_link_div(), @@ -50,4 +54,4 @@ # Run the app if __name__ == '__main__': - app.run_server(host='0.0.0.0', debug=False, port=4057, use_reloader=False, threaded=True) + app.run_server(host='0.0.0.0', debug=False, port=4057, use_reloader=True, threaded=True) diff --git a/app/components/filter.py b/app/components/filter.py new file mode 100644 index 0000000..941a124 --- /dev/null +++ b/app/components/filter.py @@ -0,0 +1,52 @@ +""" +Component in charge of filtering out prescriptors by metric. +""" +from dash import html, dcc, Input, Output +import dash_bootstrap_components as dbc +import pandas as pd + + +class FilterComponent: + def create_metric_sliders(self, metrics_df: pd.DataFrame): + sliders = [] + for col in metrics_df: + col_id = col.replace(" ", "-").replace(".", "_") + min_val = metrics_df[col].min() + max_val = metrics_df[col].max() + marks = {min_val: f"{min_val:.2f}", max_val: f"{max_val:.2f}"} + slider = dcc.RangeSlider( + id=f"{col_id}-slider", + min=min_val, + max=max_val, + value=[min_val, max_val], + marks=marks, + tooltip={"placement": "bottom", "always_visible": True}, + allowCross=False + ) + sliders.append(slider) + return sliders + + def create_filter_div(self): + div = html.Div( + className="p-3 bg-white rounded-5 mx-auto w-75 mb-3", + children=[ + dbc.Container( + fluid=True, + className="py-3", + children=[ + html.H2("Filter AI Models by Desired Metric", className="text-center"), + dcc.Loading(html.Div(id="filter-sliders"), type="circle", target_components="metrics-store") + ] + ) + ] + ) + return div + + def register_callbacks(self, app): + @app.callback( + Output("filter-sliders", "children"), + Input("metrics-store", "data") + ) + def update_filter_sliders(metrics_jsonl): + metrics_df = pd.DataFrame(metrics_jsonl) + return self.create_metric_sliders(metrics_df) diff --git a/app/components/outcome.py b/app/components/outcome.py index 791fd26..aa95161 100644 --- a/app/components/outcome.py +++ b/app/components/outcome.py @@ -1,7 +1,7 @@ """ OutcomeComponent class for the outcome section of the app. """ -from dash import Input, Output, html, dcc +from dash import Input, Output, State, html, dcc import dash_bootstrap_components as dbc import pandas as pd import plotly.express as px @@ -27,6 +27,8 @@ def __init__(self, evolution_handler: EvolutionHandler, all_cand_idxs: list[str] "Government net revenue from adjustments", "Total Primary Energy Demand"] + self.metric_ids = [metric.replace(" ", "-").replace(".", "_") for metric in self.evolution_handler.outcomes] + def plot_outcome_over_time(self, outcome: str, outcomes_jsonl: list[list[dict[str, float]]], cand_idxs: list[int]): """ Plots all the candidates' prescribed actions' outcomes for a given context. @@ -64,7 +66,7 @@ def plot_outcome_over_time(self, outcome: str, outcomes_jsonl: list[list[dict[st y=outcomes_df[outcome], mode='lines', name=str(cand_idx), - line=dict(color=color_map[self.all_cand_idxs.index(cand_idx)]) + line=dict(color=color_map[cand_idxs.index(cand_idx)]) )) if "baseline" in cand_idxs: @@ -127,6 +129,7 @@ def create_outcomes_div(self): children=[ dcc.Store(id="context-actions-store"), dcc.Store(id="outcomes-store"), + dcc.Store(id="metrics-store"), dbc.Row( children=[ dbc.Col( @@ -152,13 +155,15 @@ def register_callbacks(self, app): @app.callback( Output("context-actions-store", "data"), Output("outcomes-store", "data"), + Output("metrics-store", "data"), Output("energy-policy-store", "data"), [Input(f"context-slider-{i}", "value") for i in range(4)] ) - def update_outcomes_store(*context_values): + def update_results_stores(*context_values): """ - When the context sliders are changed, prescribe actions for the context for all candidates and return - the actions and outcomes. + When the context sliders are changed, prescribe actions for the context for all candidates. Then run them + through En-ROADS to get the outcomes. Finally process the outcomes into metrics. Store the context-actions + dicts, outcomes dfs, and metrics df in stores. Also stores the energy policies in the energy-policy-store in link.py. TODO: Make this only load selected candidates. """ @@ -170,6 +175,9 @@ def update_outcomes_store(*context_values): outcomes_jsonl = [outcomes_df[self.plot_outcomes].to_dict("records") for outcomes_df in outcomes_dfs] + metrics_df = self.evolution_handler.outcomes_to_metrics(context_actions_dicts, outcomes_dfs) + metrics_json = metrics_df.to_dict("records") + # Parse energy demand policy # colors = ["brown", "red", "blue", "green", "pink", "lightblue", "orange"] energies = ["coal", "oil", "gas", "renew and hydro", "bio", "nuclear", "new tech"] @@ -177,38 +185,30 @@ def update_outcomes_store(*context_values): selected_dfs = [outcomes_dfs[i] for i in self.all_cand_idxs[:-2]] energy_policy_jsonl = [outcomes_df[demands].to_dict("records") for outcomes_df in selected_dfs] - return context_actions_dicts, outcomes_jsonl, energy_policy_jsonl + return context_actions_dicts, outcomes_jsonl, metrics_json, energy_policy_jsonl @app.callback( Output("outcome-graph-1", "figure"), - Input("outcome-dropdown-1", "value"), - Input("outcomes-store", "data"), - [Input(f"cand-button-{cand_idx}", "outline") for cand_idx in self.all_cand_idxs] - ) - def update_outcomes_plot_1(outcome, outcomes_jsonl, *deselected): - """ - Updates outcome plot when specific outcome is selected or context scatter point is clicked. - """ - cand_idxs = [] - for cand_idx, deselect in zip(self.all_cand_idxs, deselected): - if not deselect: - cand_idxs.append(cand_idx) - fig = self.plot_outcome_over_time(outcome, outcomes_jsonl, cand_idxs) - return fig - - @app.callback( Output("outcome-graph-2", "figure"), + State("metrics-store", "data"), + Input("outcome-dropdown-1", "value"), Input("outcome-dropdown-2", "value"), Input("outcomes-store", "data"), - [Input(f"cand-button-{cand_idx}", "outline") for cand_idx in self.all_cand_idxs] + [Input(f"{metric_id}-slider", "value") for metric_id in self.metric_ids], ) - def update_outcomes_plot_2(outcome, outcomes_jsonl, *deselected): + def update_outcomes_plots(metrics_json, outcome1, outcome2, outcomes_jsonl, *metric_ranges): """ Updates outcome plot when specific outcome is selected or context scatter point is clicked. """ - cand_idxs = [] - for cand_idx, deselect in zip(self.all_cand_idxs, deselected): - if not deselect: - cand_idxs.append(cand_idx) - fig = self.plot_outcome_over_time(outcome, outcomes_jsonl, cand_idxs) - return fig + metrics_df = pd.DataFrame(metrics_json) + metric_names = list(self.evolution_handler.outcomes.keys()) + metric_name_and_range = zip(metric_names, metric_ranges) + for metric_name, metric_range in metric_name_and_range: + metrics_df = metrics_df[metrics_df[metric_name].between(*metric_range)] + + top_10_idxs = metrics_df.index[:10] + + fig1 = self.plot_outcome_over_time(outcome1, outcomes_jsonl, top_10_idxs) + fig2 = self.plot_outcome_over_time(outcome2, outcomes_jsonl, top_10_idxs) + return fig1, fig2 + diff --git a/app/utils.py b/app/utils.py index 65062e5..1af550f 100644 --- a/app/utils.py +++ b/app/utils.py @@ -104,6 +104,21 @@ def context_actions_to_outcomes(self, context_actions_dicts: list[dict[str, floa outcomes_dfs.append(outcomes_df) return outcomes_dfs + + def outcomes_to_metrics(self, + context_actions_dicts: list[dict[str, float]], + outcomes_dfs: list[pd.DataFrame]) -> pd.DataFrame: + """ + Takes parallel lists of context_actions_dicts and outcomes_dfs and processes them into a metrics dict. + All of these metrics dicts are then concatenated into a single DataFrame. + """ + metrics_dicts = [] + for context_actions_dict, outcomes_df in zip(context_actions_dicts, outcomes_dfs): + metrics = self.outcome_manager.process_outcomes(context_actions_dict, outcomes_df) + metrics_dicts.append(metrics) + + metrics_df = pd.DataFrame(metrics_dicts) + return metrics_df def context_baseline_outcomes(self, context_dict: dict[str, float]): """ From 5e047cd6b6fe99c7d38428c1ed145ab520431707 Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Mon, 9 Sep 2024 16:29:31 -0700 Subject: [PATCH 02/11] Added new prescriptor filtering, removed parallel coordinates for now. --- app/app.py | 5 +- app/components/context.py | 17 +++--- app/components/filter.py | 93 ++++++++++++++++++++++++----- app/components/link.py | 78 ++++++++++++++++++------- app/components/outcome.py | 119 ++++++++++++++++++++++---------------- app/utils.py | 2 +- 6 files changed, 215 insertions(+), 99 deletions(-) diff --git a/app/app.py b/app/app.py index ceef664..f8d8fcf 100644 --- a/app/app.py +++ b/app/app.py @@ -15,12 +15,13 @@ from app.utils import EvolutionHandler evolution_handler = EvolutionHandler() +metrics = evolution_handler.outcomes.keys() # The candidates are sorted by rank then distance so the 'best' ones are the first 10 sample_idxs = list(range(10)) intro_component = IntroComponent() context_component = ContextComponent() -filter_component = FilterComponent() +filter_component = FilterComponent(metrics) parallel_component = ParallelComponent(evolution_handler.load_initial_metrics_df(), sample_idxs, evolution_handler.outcomes) @@ -45,7 +46,7 @@ intro_component.create_intro_div(), context_component.create_context_div(), filter_component.create_filter_div(), - parallel_component.create_parallel_div(), + # parallel_component.create_parallel_div(), outcome_component.create_outcomes_div(), link_component.create_link_div(), references_component.create_references_div() diff --git a/app/components/context.py b/app/components/context.py index e21d6c0..d3e2088 100644 --- a/app/components/context.py +++ b/app/components/context.py @@ -127,17 +127,14 @@ def create_context_div(self): children=[ dbc.Container( fluid=True, - className="py-3 d-flex flex-column h-100", + className="py-3 d-flex flex-column h-100 w-70", children=[ - dbc.Row(html.H2("Select a Context Scenario to Optimize For", className="text-center mb-5")), - dbc.Row( - className="mb-2 w-70 text-center mx-auto", - children=[html.P("According to the AR6 climate report: 'The five Shared Socioeconomic \ - Pathways were designed to span a range of challenges to climate change \ - mitigation and adaptation.' Select one of these scenarios by clicking it \ - in the scatter plot below. If desired, manually modify the scenario \ - with the sliders.")] - ), + html.H2("Select a Context Scenario to Optimize For", className="text-center mb-5"), + html.P("According to the AR6 climate report: 'The five Shared Socioeconomic \ + Pathways were designed to span a range of challenges to climate change \ + mitigation and adaptation.' Select one of these scenarios by clicking it \ + in the scatter plot below. If desired, manually modify the scenario \ + with the sliders.", className="mb-2 text-center mx-auto"), dbc.Row( className="flex-grow-1", children=[ diff --git a/app/components/filter.py b/app/components/filter.py index 941a124..295b648 100644 --- a/app/components/filter.py +++ b/app/components/filter.py @@ -7,26 +7,52 @@ class FilterComponent: - def create_metric_sliders(self, metrics_df: pd.DataFrame): + """ + Component in charge of filtering out prescriptors by metric specific to each context. + The component stores the metrics to filter with as long as their corresponding HTML ids. + It also keeps track of the parameters that need to be updated for each slider. + """ + def __init__(self, metrics: list[str]): + self.metrics = list(metrics) + self.metric_ids = [metric.replace(" ", "-").replace(".", "_") for metric in self.metrics] + self.updated_params = ["min", "max", "value", "marks"] + + def create_metric_sliders(self): + """ + Creates initial metric sliders and lines them up with their labels. + """ sliders = [] - for col in metrics_df: - col_id = col.replace(" ", "-").replace(".", "_") - min_val = metrics_df[col].min() - max_val = metrics_df[col].max() - marks = {min_val: f"{min_val:.2f}", max_val: f"{max_val:.2f}"} + for metric in self.metrics: + col_id = metric.replace(" ", "-").replace(".", "_") slider = dcc.RangeSlider( id=f"{col_id}-slider", - min=min_val, - max=max_val, - value=[min_val, max_val], - marks=marks, + min=0, + max=1, + value=[0, 1], + marks={0: f"{0:.2f}", 1: f"{1:.2f}"}, tooltip={"placement": "bottom", "always_visible": True}, allowCross=False ) sliders.append(slider) - return sliders + + div = html.Div( + children=[ + dbc.Row( + children=[ + dbc.Col(html.Label(self.metrics[i]), width=4), + dbc.Col(sliders[i], width=8) + ] + ) + for i in range(len(self.metrics)) + ] + ) + return div def create_filter_div(self): + """ + Creates div showing sliders to choose the range of metric values we want the prescriptors to have. + TODO: Currently the slider tooltips show even while loading which is a bit of an eyesore. + """ div = html.Div( className="p-3 bg-white rounded-5 mx-auto w-75 mb-3", children=[ @@ -34,8 +60,20 @@ def create_filter_div(self): fluid=True, className="py-3", children=[ - html.H2("Filter AI Models by Desired Metric", className="text-center"), - dcc.Loading(html.Div(id="filter-sliders"), type="circle", target_components="metrics-store") + html.H2("Filter AI Models by Desired Metric", className="text-center mb-5"), + html.P("One hundred AI models are trained to create different energy policies that have a \ + diverse range of outcomes. Use the sliders below to filter the models that align with a \ + desired behavior resulting from their automatically generated energy policy. See how \ + this filtering affects the behavior of the policies in the below sections.", + className="text-center"), + dcc.Loading( + type="circle", + target_components={"metrics-store": "*"}, + children=[ + self.create_metric_sliders(), + dcc.Store(id="metrics-store") + ], + ) ] ) ] @@ -43,10 +81,33 @@ def create_filter_div(self): return div def register_callbacks(self, app): + """ + Registers callbacks related to the filter sliders. + """ @app.callback( - Output("filter-sliders", "children"), + [Output(f"{metric_id}-slider", param) for metric_id in self.metric_ids for param in self.updated_params], Input("metrics-store", "data") ) - def update_filter_sliders(metrics_jsonl): + def update_filter_sliders(metrics_jsonl: list[dict[str, list]]) -> list: + """ + Update the filter slider min/max/value/marks based on the incoming metrics data. The output of this function + is a list of the updated parameters for each slider concatenated. + """ metrics_df = pd.DataFrame(metrics_jsonl) - return self.create_metric_sliders(metrics_df) + total_output = [] + for metric in self.metrics: + metric_output = [] + min_val = metrics_df[metric].min() + max_val = metrics_df[metric].max() + # We need to round down for the min value and round up for the max value + min_val_rounded = min_val // 0.01 / 100 + max_val_rounded = max_val + 0.01 + metric_output = [ + min_val_rounded, + max_val_rounded, + [min_val_rounded, max_val_rounded], + {min_val_rounded: f"{min_val_rounded:.2f}", max_val_rounded: f"{max_val_rounded:.2f}"} + ] + total_output.extend(metric_output) + + return total_output diff --git a/app/components/link.py b/app/components/link.py index 9e1ff39..bd7ae23 100644 --- a/app/components/link.py +++ b/app/components/link.py @@ -21,7 +21,7 @@ def __init__(self, cand_idxs: list[int]): self.energies = ["coal", "oil", "gas", "renew and hydro", "bio", "nuclear", "new tech"] self.demands = [f"Primary energy demand of {energy}" for energy in self.energies] - def plot_energy_policy(self, energy_policy_jsonl, cand_idx): + def plot_energy_policy(self, energy_policy_jsonl: list[dict[str, list]], cand_idx: int) -> go.Figure: """ Plots density chart from energy policy. Removes demands that are all 0. @@ -54,15 +54,21 @@ def plot_energy_policy(self, energy_policy_jsonl, cand_idx): legend=dict( orientation="h", yanchor="bottom", - y=1.02, + y=1, xanchor="right", x=1 ), - margin=dict(l=0, r=0, t=0, b=0), + margin=dict(l=0, r=0, t=100, b=0), + title=dict( + text=f"Model {cand_idx} Energy Policy", + x=0.5, + y=0.9, + xanchor="center" + ) ) return fig - def create_button_group(self): + def create_button_group(self) -> html.Div: """ Creates button group to select candidate to link to. """ @@ -93,18 +99,14 @@ def create_link_div(self): children=[ dbc.Container( fluid=True, - className="py-3 d-flex flex-column justify-content-center", + className="py-3 justify-content-center", children=[ html.H2("View Energy Policy and Visualize/Modify Actions in En-ROADS", className="text-center mb-2"), - html.P("Click on a candidate to preview the distribution of energy sources over time due to \ + html.P("Select a candidate to preview the distribution of energy sources over time due to \ its prescribed energy policy. Then click on the link to view the full policy in \ En-ROADS.", className="text-center w-70 mb-2 mx-auto"), - html.Div( - className="w-50 mx-auto", - children=[self.create_button_group()] - ), dcc.Loading( type="circle", children=[ @@ -112,11 +114,29 @@ def create_link_div(self): dcc.Graph(id="energy-policy-graph", className="mb-2") ] ), - dbc.Button("View in En-ROADS", - id="cand-link", - target="_blank", - rel="noopener noreferrer", - className="w-25 mx-auto") + dbc.Row( + className="w-75 mx-auto", + justify="center", + children=[ + dbc.Col( + dcc.Dropdown( + id="cand-link-select", + options=[], + placeholder="Select an AI Model" + ), + width={"size": 3, "offset": 3} + ), + dbc.Col( + dbc.Button( + "View in En-ROADS", + id="cand-link", + target="_blank", + rel="noopener noreferrer", + disabled=True + ), + ) + ] + ), ] ) ] @@ -129,18 +149,34 @@ def register_callbacks(self, app): Input("energy-policy-store", "data"), Input("cand-link-select", "value") ) - def update_energy_policy_graph(energy_policy_jsonl, cand_idx): - return self.plot_energy_policy(energy_policy_jsonl, cand_idx) + def update_energy_policy_graph(energy_policy_jsonl: list[dict[str, list]], cand_idx) -> go.Figure: + if cand_idx is not None: + return self.plot_energy_policy(energy_policy_jsonl, cand_idx) + + # If we have no cand id just return a blank figure asking the user to select a candidate. + fig = go.Figure() + fig.update_layout( + title=dict( + text="Select an AI model to view its policy", + x=0.5, + xanchor="center" + ) + ) + return fig @app.callback( Output("cand-link", "href"), + Output("cand-link", "disabled"), Input("context-actions-store", "data"), Input("cand-link-select", "value") ) - def update_cand_links(context_actions_dicts: list[dict[str, float]], cand_idx): + def update_cand_link(context_actions_dicts: list[dict[str, float]], cand_idx) -> tuple[str, bool]: """ Updates the candidate link when a specific candidate is selected. + Additionally un-disables the button if this is the first time we're selecting a candidate. """ - cand_dict = context_actions_dicts[cand_idx] - link = actions_to_url(cand_dict) - return link + if cand_idx is not None: + cand_dict = context_actions_dicts[cand_idx] + link = actions_to_url(cand_dict) + return link, False + return "", True diff --git a/app/components/outcome.py b/app/components/outcome.py index aa95161..978518e 100644 --- a/app/components/outcome.py +++ b/app/components/outcome.py @@ -35,56 +35,62 @@ def plot_outcome_over_time(self, outcome: str, outcomes_jsonl: list[list[dict[st Also plots the baseline given the context. TODO: Fix colors to match parcoords """ + best_cand_idxs = cand_idxs[:10] outcomes_dfs = [pd.DataFrame(outcomes_json) for outcomes_json in outcomes_jsonl] color_map = px.colors.qualitative.Plotly fig = go.Figure() showlegend = True - if "other" in cand_idxs: - for cand_idx, outcomes_df in enumerate(outcomes_dfs[:-1]): - if cand_idx not in cand_idxs: - outcomes_df["year"] = list(range(1990, 2101)) - # Legend group other so all the other candidates get removed when we click on it. - # Name other because the first other candidate represents them all in the legend - fig.add_trace(go.Scatter( - x=outcomes_df["year"], - y=outcomes_df[outcome], - mode='lines', - legendgroup="other", - name="other", - showlegend=showlegend, - line=dict(color="lightgray") - )) - showlegend = False - - for cand_idx in cand_idxs: - if cand_idx != "baseline" and cand_idx != "other": - outcomes_df = outcomes_dfs[cand_idx] - outcomes_df["year"] = list(range(1990, 2101)) - fig.add_trace(go.Scatter( - x=outcomes_df["year"], - y=outcomes_df[outcome], - mode='lines', - name=str(cand_idx), - line=dict(color=color_map[cand_idxs.index(cand_idx)]) - )) - - if "baseline" in cand_idxs: - baseline_outcomes_df = outcomes_dfs[-1] - baseline_outcomes_df["year"] = list(range(1990, 2101)) + for cand_idx in cand_idxs[10:]: + outcomes_df = outcomes_dfs[cand_idx] + outcomes_df["year"] = list(range(1990, 2101)) + # Legend group other so all the other candidates get removed when we click on it. + # Name other because the first other candidate represents them all in the legend fig.add_trace(go.Scatter( - x=baseline_outcomes_df["year"], - y=baseline_outcomes_df[outcome], + x=outcomes_df["year"], + y=outcomes_df[outcome], mode='lines', - name="baseline", - line=dict(color="black") + legendgroup="other", + name="other", + showlegend=showlegend, + line=dict(color="lightgray") )) + showlegend = False + + for cand_idx in best_cand_idxs: + outcomes_df = outcomes_dfs[cand_idx] + outcomes_df["year"] = list(range(1990, 2101)) + fig.add_trace(go.Scatter( + x=outcomes_df["year"], + y=outcomes_df[outcome], + mode='lines', + name=str(cand_idx), + line=dict(color=color_map[cand_idxs.index(cand_idx)]), + showlegend=True + )) + + baseline_outcomes_df = outcomes_dfs[-1] + baseline_outcomes_df["year"] = list(range(1990, 2101)) + fig.add_trace(go.Scatter( + x=baseline_outcomes_df["year"], + y=baseline_outcomes_df[outcome], + mode='lines', + name="baseline", + line=dict(color="black"), + showlegend=True + )) + + # Standardize the max and min of the y-axis so that the graphs are comparable when we start filtering + # models. + y_min = min([outcomes_df[outcome].min() for outcomes_df in outcomes_dfs]) + y_max = max([outcomes_df[outcome].max() for outcomes_df in outcomes_dfs]) fig.update_layout( title={ "text": f"{outcome} Over Time", "x": 0.5, "xanchor": "center"}, + yaxis_range=[y_min, y_max] ) return fig @@ -129,7 +135,6 @@ def create_outcomes_div(self): children=[ dcc.Store(id="context-actions-store"), dcc.Store(id="outcomes-store"), - dcc.Store(id="metrics-store"), dbc.Row( children=[ dbc.Col( @@ -151,6 +156,18 @@ def create_outcomes_div(self): return div + def filter_metrics_json(self, metrics_json: pd.DataFrame, metric_ranges: list[tuple[float, float]]): + """ + Converts metrics json stored in the metrics store to a DataFrame then filters it based on metric ranges from + sliders. + """ + metrics_df = pd.DataFrame(metrics_json) + metric_names = list(self.evolution_handler.outcomes.keys()) + metric_name_and_range = zip(metric_names, metric_ranges) + for metric_name, metric_range in metric_name_and_range: + metrics_df = metrics_df[metrics_df[metric_name].between(*metric_range)] + return metrics_df + def register_callbacks(self, app): @app.callback( Output("context-actions-store", "data"), @@ -182,8 +199,7 @@ def update_results_stores(*context_values): # colors = ["brown", "red", "blue", "green", "pink", "lightblue", "orange"] energies = ["coal", "oil", "gas", "renew and hydro", "bio", "nuclear", "new tech"] demands = [f"Primary energy demand of {energy}" for energy in energies] - selected_dfs = [outcomes_dfs[i] for i in self.all_cand_idxs[:-2]] - energy_policy_jsonl = [outcomes_df[demands].to_dict("records") for outcomes_df in selected_dfs] + energy_policy_jsonl = [outcomes_df[demands].to_dict("records") for outcomes_df in outcomes_dfs] return context_actions_dicts, outcomes_jsonl, metrics_json, energy_policy_jsonl @@ -200,15 +216,20 @@ def update_outcomes_plots(metrics_json, outcome1, outcome2, outcomes_jsonl, *met """ Updates outcome plot when specific outcome is selected or context scatter point is clicked. """ - metrics_df = pd.DataFrame(metrics_json) - metric_names = list(self.evolution_handler.outcomes.keys()) - metric_name_and_range = zip(metric_names, metric_ranges) - for metric_name, metric_range in metric_name_and_range: - metrics_df = metrics_df[metrics_df[metric_name].between(*metric_range)] - - top_10_idxs = metrics_df.index[:10] + metrics_df = self.filter_metrics_json(metrics_json, metric_ranges) + cand_idxs = list(metrics_df.index) - fig1 = self.plot_outcome_over_time(outcome1, outcomes_jsonl, top_10_idxs) - fig2 = self.plot_outcome_over_time(outcome2, outcomes_jsonl, top_10_idxs) + fig1 = self.plot_outcome_over_time(outcome1, outcomes_jsonl, cand_idxs) + fig2 = self.plot_outcome_over_time(outcome2, outcomes_jsonl, cand_idxs) return fig1, fig2 - + + @app.callback( + Output("cand-link-select", "options"), + State("metrics-store", "data"), + [Input(f"{metric_id}-slider", "value") for metric_id in self.metric_ids] + ) + def update_cand_link_select(metrics_json: dict[str, list], + *metric_ranges: list[tuple[float, float]]) -> list[int]: + metrics_df = self.filter_metrics_json(metrics_json, metric_ranges) + cand_idxs = list(metrics_df.index) + return cand_idxs diff --git a/app/utils.py b/app/utils.py index 1af550f..300feea 100644 --- a/app/utils.py +++ b/app/utils.py @@ -104,7 +104,7 @@ def context_actions_to_outcomes(self, context_actions_dicts: list[dict[str, floa outcomes_dfs.append(outcomes_df) return outcomes_dfs - + def outcomes_to_metrics(self, context_actions_dicts: list[dict[str, float]], outcomes_dfs: list[pd.DataFrame]) -> pd.DataFrame: From 15c351ca3d344be5370660dad62628f514ae0f5b Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Mon, 9 Sep 2024 16:40:49 -0700 Subject: [PATCH 03/11] Updated readme --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 6d56c00..f433fe0 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,15 @@ # Decision Making for Climate Change -Immediate action is required to combat climate change. The technology behind Cognizant NeuroAI brings automatic decision-making to the En-ROADS platform, a powerful climate change simulator. A decision-maker can be ready for any scenario: choosing an automatically generated policy that suits their needs best, with the ability to manually modify the policy and see its results. This tool is brought together under Project Resilience, a United Nations initiative to use AI for good. +Immediate action is required to combat climate change. The technology behind [Cognizant NeuroAI](https://evolution.ml/) brings automatic decision-making to the En-ROADS platform, a powerful climate change simulator. A decision-maker can be ready for any scenario: choosing an automatically generated policy that suits their needs best, with the ability to manually modify the policy and see its results. This tool is brought together under Project Resilience, a United Nations initiative to use AI for good. ## En-ROADS Wrapper -En-ROADS is a climate change simulator developed by Climate Interactive. We have created a wrapper around the SDK to make it simple to use in a Python application which can be found in `enroadspy`. See `enroads_runner.py` for the main class that runs the SDK. The SDK is not included in this repository and must be requested from Climate Interactive. +En-ROADS is a climate change simulator developed by Climate Interactive. We have created a wrapper around the SDK to make it simple to use in a Python application which can be found in `enroadspy`. See `enroads_runner.py` for the main class that runs the SDK. The SDK is not included in this repository and must be requested from [Climate Interactive](https://www.climateinteractive.org/). -The input data format is a crazy long JSON object which I copied out of the source code, pasted into `inputSpecs.py`, and parsed into `inputSpecs.jsonl`. This format is used by the rest of the code. +The input data format is a crazy long JSON object which I copied out of the source code, pasted into `inputSpecs.py`, and parsed into `inputSpecs.jsonl`. This format is used by the rest of the repository. ### Installation -Download the en-roads zip file, place it in the root folder of the repository, and unzip it. +Run `pip install -r requirements.txt` to install the required packages. Then run `python -m enroadspy.download_sdk` to download the SDK. In order to download the SDK environment variables must be set which can be requested online from Climate Interactive. ## Evolution From 4b865c5be25fb7ccbf3e36f2c7833cc73c407c51 Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Tue, 10 Sep 2024 14:59:08 -0700 Subject: [PATCH 04/11] Did some rewording and added toggleable parallel coordinates plot --- app/app.py | 12 ++-- app/components/filter.py | 135 ++++++++++++++++++++++++++++++++++++-- app/components/link.py | 23 +++---- app/components/outcome.py | 41 ++++++------ app/utils.py | 21 ++++++ 5 files changed, 187 insertions(+), 45 deletions(-) diff --git a/app/app.py b/app/app.py index f8d8fcf..72cea2d 100644 --- a/app/app.py +++ b/app/app.py @@ -8,7 +8,7 @@ from app.components.intro import IntroComponent from app.components.context import ContextComponent from app.components.filter import FilterComponent -from app.components.parallel import ParallelComponent +# from app.components.parallel import ParallelComponent from app.components.outcome import OutcomeComponent from app.components.link import LinkComponent from app.components.references import ReferencesComponent @@ -22,10 +22,10 @@ intro_component = IntroComponent() context_component = ContextComponent() filter_component = FilterComponent(metrics) -parallel_component = ParallelComponent(evolution_handler.load_initial_metrics_df(), - sample_idxs, - evolution_handler.outcomes) -outcome_component = OutcomeComponent(evolution_handler, sample_idxs) +# parallel_component = ParallelComponent(evolution_handler.load_initial_metrics_df(), +# sample_idxs, +# evolution_handler.outcomes) +outcome_component = OutcomeComponent(evolution_handler) link_component = LinkComponent(sample_idxs) references_component = ReferencesComponent() @@ -36,7 +36,7 @@ context_component.register_callbacks(app) filter_component.register_callbacks(app) -parallel_component.register_callbacks(app) +# parallel_component.register_callbacks(app) outcome_component.register_callbacks(app) link_component.register_callbacks(app) diff --git a/app/components/filter.py b/app/components/filter.py index 295b648..3f44c82 100644 --- a/app/components/filter.py +++ b/app/components/filter.py @@ -1,9 +1,13 @@ """ Component in charge of filtering out prescriptors by metric. """ -from dash import html, dcc, Input, Output +from dash import html, dcc, Input, Output, State import dash_bootstrap_components as dbc import pandas as pd +import plotly.express as px +import plotly.graph_objects as go + +from app.utils import filter_metrics_json class FilterComponent: @@ -38,6 +42,7 @@ def create_metric_sliders(self): div = html.Div( children=[ dbc.Row( + className="mb-2", children=[ dbc.Col(html.Label(self.metrics[i]), width=4), dbc.Col(sliders[i], width=8) @@ -47,6 +52,75 @@ def create_metric_sliders(self): ] ) return div + + def plot_parallel_coordinates_line(self, + metrics_json: dict[str, list], + metric_ranges: list[tuple[float, float]]) -> go.Figure: + """ + NOTE: This is legacy code that may be brought back in later for a user toggle. + Plots a parallel coordinates plot of the prescriptor metrics. + Starts by plotting "other" if selected so that it's the bottom of the z axis. + Then plots selected candidates in color. + Finally plots the baseline on top if selected. + """ + fig = go.Figure() + + normalized_df = filter_metrics_json(metrics_json, metric_ranges, normalize=True) + + cand_idxs = list(normalized_df.index)[:-1] # Leave out the baseline + n_special_cands = min(10, len(cand_idxs)) + + showlegend = True + # If "other" is in the cand_idxs, plot all other candidates in lightgray + for cand_idx in cand_idxs[n_special_cands:]: + cand_metrics = normalized_df.loc[cand_idx].values + fig.add_trace(go.Scatter( + x=normalized_df.columns, + y=cand_metrics, + mode='lines', + legendgroup="other", + name="other", + line=dict(color="lightgray"), + showlegend=showlegend + )) + showlegend = False + + # Plot selected candidates besides baseline so it can be on top + for color_idx, cand_idx in enumerate(cand_idxs[:n_special_cands]): + cand_metrics = normalized_df.loc[cand_idx].values + fig.add_trace(go.Scatter( + x=normalized_df.columns, + y=cand_metrics, + mode='lines', + name=str(cand_idx), + line=dict(color=px.colors.qualitative.Plotly[color_idx]) + )) + + baseline_metrics = normalized_df.iloc[-1] + fig.add_trace(go.Scatter( + x=normalized_df.columns, + y=baseline_metrics.values, + mode='lines', + name="baseline", + line=dict(color="black") + )) + + for i in range(len(normalized_df.columns)): + fig.add_vline(x=i, line_color="black") + + full_metrics_df = pd.DataFrame(metrics_json) + normalized_full = (full_metrics_df - full_metrics_df.mean()) / (full_metrics_df.std() + 1e-10) + fig.update_layout( + yaxis_range=[normalized_full.min().min(), normalized_full.max().max()], + title={ + 'text': "Normalized Policy Metrics", + 'x': 0.5, # Center the title + 'xanchor': 'center', # Anchor it at the center + 'yanchor': 'top' # Optionally keep it anchored to the top + } + ) + + return fig def create_filter_div(self): """ @@ -60,11 +134,12 @@ def create_filter_div(self): fluid=True, className="py-3", children=[ - html.H2("Filter AI Models by Desired Metric", className="text-center mb-5"), - html.P("One hundred AI models are trained to create different energy policies that have a \ - diverse range of outcomes. Use the sliders below to filter the models that align with a \ - desired behavior resulting from their automatically generated energy policy. See how \ - this filtering affects the behavior of the policies in the below sections.", + html.H2("Filter Policies by Desired Behavior", className="text-center mb-5"), + html.P("One hundred AI models are trained to create different energy policies that make trade \ + offs in metrics. Use the sliders below to filter the AI generated policies \ + that produce desired behavior resulting from their automatically generated energy \ + policy. See how this filtering affects the behavior of the policies in the below \ + sections.", className="text-center"), dcc.Loading( type="circle", @@ -73,6 +148,30 @@ def create_filter_div(self): self.create_metric_sliders(), dcc.Store(id="metrics-store") ], + ), + html.Div( + className="d-flex flex-column align-items-center", + children=[ + dbc.Button( + "Toggle Detailed Select", + id="parcoords-collapse-button", + className="mb-3", + color="secondary", + outline=True, + n_clicks=0 + ), + dbc.Collapse( + children=[ + dbc.Card( + dcc.Graph(id="parcoords-figure"), + color="secondary" + ) + ], + id="parcoords-collapse", + className="bg-gray rounded-5", + is_open=False + ) + ] ) ] ) @@ -111,3 +210,27 @@ def update_filter_sliders(metrics_jsonl: list[dict[str, list]]) -> list: total_output.extend(metric_output) return total_output + + @app.callback( + Output("parcoords-collapse", "is_open"), + Input("parcoords-collapse-button", "n_clicks"), + State("parcoords-collapse", "is_open") + ) + def toggle_parcoords_collapse(n, is_open): + """ + Toggles collapse. From dbc documentation. + """ + if n: + return not is_open + return is_open + + @app.callback( + Output("parcoords-figure", "figure"), + State("metrics-store", "data"), + [Input(f"{metric_id}-slider", "value") for metric_id in self.metric_ids] + ) + def filter_parcoords_figure(metrics_json: dict[str, list], *metric_ranges) -> go.Figure: + """ + Filters parallel coordinates figure based on the metric ranges from the sliders. + """ + return self.plot_parallel_coordinates_line(metrics_json, metric_ranges) diff --git a/app/components/link.py b/app/components/link.py index bd7ae23..b5d4724 100644 --- a/app/components/link.py +++ b/app/components/link.py @@ -60,7 +60,7 @@ def plot_energy_policy(self, energy_policy_jsonl: list[dict[str, list]], cand_id ), margin=dict(l=0, r=0, t=100, b=0), title=dict( - text=f"Model {cand_idx} Energy Policy", + text=f"Policy {cand_idx} Energy Source Distribution Over Time", x=0.5, y=0.9, xanchor="center" @@ -101,11 +101,10 @@ def create_link_div(self): fluid=True, className="py-3 justify-content-center", children=[ - html.H2("View Energy Policy and Visualize/Modify Actions in En-ROADS", - className="text-center mb-2"), - html.P("Select a candidate to preview the distribution of energy sources over time due to \ - its prescribed energy policy. Then click on the link to view the full policy in \ - En-ROADS.", + html.H2("View Policy Energy Sources and Explore Policy in En-ROADS", + className="text-center mb-5"), + html.P("Select a policy to preview its resulting distribution of energy sources over time. \ + Then click on the link to explore and fine-tune the policy in En-ROADS.", className="text-center w-70 mb-2 mx-auto"), dcc.Loading( type="circle", @@ -116,24 +115,26 @@ def create_link_div(self): ), dbc.Row( className="w-75 mx-auto", - justify="center", + align="center", children=[ dbc.Col( dcc.Dropdown( id="cand-link-select", options=[], - placeholder="Select an AI Model" + placeholder="Select a policy" ), - width={"size": 3, "offset": 3} + width={"size": 3, "offset": 1} ), dbc.Col( dbc.Button( - "View in En-ROADS", + "Explore & Fine-Tune Policy in En-ROADS", id="cand-link", target="_blank", rel="noopener noreferrer", + size="lg", disabled=True ), + width={"size": 4} ) ] ), @@ -157,7 +158,7 @@ def update_energy_policy_graph(energy_policy_jsonl: list[dict[str, list]], cand_ fig = go.Figure() fig.update_layout( title=dict( - text="Select an AI model to view its policy", + text="Select a policy to view its energy source distribution", x=0.5, xanchor="center" ) diff --git a/app/components/outcome.py b/app/components/outcome.py index 978518e..8037e88 100644 --- a/app/components/outcome.py +++ b/app/components/outcome.py @@ -6,7 +6,8 @@ import pandas as pd import plotly.express as px import plotly.graph_objects as go -from app.utils import EvolutionHandler + +from app.utils import EvolutionHandler, filter_metrics_json class OutcomeComponent(): @@ -15,9 +16,8 @@ class OutcomeComponent(): Has drop downs to allow the user to select which outcomes they want to see. TODO: Make it so we only load the selected prescriptors. """ - def __init__(self, evolution_handler: EvolutionHandler, all_cand_idxs: list[str]): + def __init__(self, evolution_handler: EvolutionHandler): self.evolution_handler = evolution_handler - self.all_cand_idxs = all_cand_idxs + ["baseline", "other"] self.context_cols = ["_long_term_gdp_per_capita_rate", "_near_term_gdp_per_capita_rate", "_transition_time_to_reach_long_term_gdp_per_capita_rate", @@ -106,7 +106,7 @@ def create_outcomes_div(self): fluid=True, className="py-3", children=[ - dbc.Row(html.H2("Outcomes of Prescribed Actions", className="text-center mb-5")), + dbc.Row(html.H2("Outcomes for Selected Policies", className="text-center mb-5")), dbc.Row( children=[ dbc.Col( @@ -156,19 +156,10 @@ def create_outcomes_div(self): return div - def filter_metrics_json(self, metrics_json: pd.DataFrame, metric_ranges: list[tuple[float, float]]): + def register_callbacks(self, app): """ - Converts metrics json stored in the metrics store to a DataFrame then filters it based on metric ranges from - sliders. + Registers callbacks relating to the outcomes section of the app. """ - metrics_df = pd.DataFrame(metrics_json) - metric_names = list(self.evolution_handler.outcomes.keys()) - metric_name_and_range = zip(metric_names, metric_ranges) - for metric_name, metric_range in metric_name_and_range: - metrics_df = metrics_df[metrics_df[metric_name].between(*metric_range)] - return metrics_df - - def register_callbacks(self, app): @app.callback( Output("context-actions-store", "data"), Output("outcomes-store", "data"), @@ -184,19 +175,22 @@ def update_results_stores(*context_values): Also stores the energy policies in the energy-policy-store in link.py. TODO: Make this only load selected candidates. """ + # Prescribe actions for all candidates via. torch context_dict = dict(zip(self.context_cols, context_values)) context_actions_dicts = self.evolution_handler.prescribe_all(context_dict) - outcomes_dfs = self.evolution_handler.context_actions_to_outcomes(context_actions_dicts) - baseline_outcomes_df = self.evolution_handler.context_baseline_outcomes(context_dict) - outcomes_dfs.append(baseline_outcomes_df) + # Attach baseline (no actions) + context_actions_dicts.append(dict(**context_dict)) + + # Run En-ROADS on all candidates and save as jsonl + outcomes_dfs = self.evolution_handler.context_actions_to_outcomes(context_actions_dicts) outcomes_jsonl = [outcomes_df[self.plot_outcomes].to_dict("records") for outcomes_df in outcomes_dfs] + # Process outcomes into metrics and save metrics_df = self.evolution_handler.outcomes_to_metrics(context_actions_dicts, outcomes_dfs) metrics_json = metrics_df.to_dict("records") - # Parse energy demand policy - # colors = ["brown", "red", "blue", "green", "pink", "lightblue", "orange"] + # Parse energy demand policy from outcomes for use in link.py energies = ["coal", "oil", "gas", "renew and hydro", "bio", "nuclear", "new tech"] demands = [f"Primary energy demand of {energy}" for energy in energies] energy_policy_jsonl = [outcomes_df[demands].to_dict("records") for outcomes_df in outcomes_dfs] @@ -216,7 +210,7 @@ def update_outcomes_plots(metrics_json, outcome1, outcome2, outcomes_jsonl, *met """ Updates outcome plot when specific outcome is selected or context scatter point is clicked. """ - metrics_df = self.filter_metrics_json(metrics_json, metric_ranges) + metrics_df = filter_metrics_json(metrics_json, metric_ranges) cand_idxs = list(metrics_df.index) fig1 = self.plot_outcome_over_time(outcome1, outcomes_jsonl, cand_idxs) @@ -230,6 +224,9 @@ def update_outcomes_plots(metrics_json, outcome1, outcome2, outcomes_jsonl, *met ) def update_cand_link_select(metrics_json: dict[str, list], *metric_ranges: list[tuple[float, float]]) -> list[int]: - metrics_df = self.filter_metrics_json(metrics_json, metric_ranges) + """ + Updates the available candidates in the link dropdown based on metric ranges. + """ + metrics_df = filter_metrics_json(metrics_json, metric_ranges) cand_idxs = list(metrics_df.index) return cand_idxs diff --git a/app/utils.py b/app/utils.py index 300feea..34a958f 100644 --- a/app/utils.py +++ b/app/utils.py @@ -13,6 +13,27 @@ from evolution.outcomes.outcome_manager import OutcomeManager +def filter_metrics_json(metrics_json: dict[str, list], + metric_ranges: list[tuple[float, float]], + normalize=False) -> pd.DataFrame: + """ + Converts metrics json stored in the metrics store to a DataFrame then filters it based on metric ranges from + sliders. + """ + metrics_df = pd.DataFrame(metrics_json) + mu = metrics_df.mean() + sigma = metrics_df.std() + metric_names = metrics_df.columns + metric_name_and_range = zip(metric_names, metric_ranges) + for metric_name, metric_range in metric_name_and_range: + # Never filter out the baseline + condition = (metrics_df[metric_name].between(*metric_range)) | (metrics_df.index == metrics_df.index[-1]) + metrics_df = metrics_df[condition] + if normalize: + metrics_df = (metrics_df - mu) / (sigma + 1e-10) + return metrics_df + + class EvolutionHandler(): """ Handles evolution results and running of prescriptors for the app. From ea04f32ce5e824cf0b50e35ea95678415a355170 Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Tue, 10 Sep 2024 16:44:44 -0700 Subject: [PATCH 05/11] Started cleaning up bootstrap. Now outcomes look funny --- app/classes.py | 10 +++++ app/components/context.py | 11 ++--- app/components/filter.py | 88 +++++++++++++++++---------------------- app/components/outcome.py | 53 +++++++++++------------ 4 files changed, 78 insertions(+), 84 deletions(-) create mode 100644 app/classes.py diff --git a/app/classes.py b/app/classes.py new file mode 100644 index 0000000..cc52e47 --- /dev/null +++ b/app/classes.py @@ -0,0 +1,10 @@ +""" +Hard-coded combinations of classes using bootstrap akin to custom css classes. +""" +DESC_TEXT = "mb-5 w-75 text-center" + +JUMBOTRON = "p-3 bg-white rounded-5 mx-auto w-75 mb-3" + +CONTAINER = "py-3 d-flex flex-column h-100 align-items-center" + +HEADER = "text-center mb-5" diff --git a/app/components/context.py b/app/components/context.py index d3e2088..10f0f58 100644 --- a/app/components/context.py +++ b/app/components/context.py @@ -8,6 +8,7 @@ import plotly.express as px import plotly.graph_objects as go +from app.classes import JUMBOTRON, CONTAINER, DESC_TEXT, HEADER from enroadspy import load_input_specs @@ -123,18 +124,18 @@ def create_context_div(self): ) div = html.Div( - className="p-3 bg-white rounded-5 mx-auto w-75 mb-3", + className=JUMBOTRON, children=[ dbc.Container( fluid=True, - className="py-3 d-flex flex-column h-100 w-70", + className=CONTAINER, children=[ - html.H2("Select a Context Scenario to Optimize For", className="text-center mb-5"), + html.H2("Select a Context Scenario to Optimize For", className=HEADER), html.P("According to the AR6 climate report: 'The five Shared Socioeconomic \ Pathways were designed to span a range of challenges to climate change \ - mitigation and adaptation.' Select one of these scenarios by clicking it \ + mitigation and adaptation'. Select one of these scenarios by clicking it \ in the scatter plot below. If desired, manually modify the scenario \ - with the sliders.", className="mb-2 text-center mx-auto"), + with the sliders.", className=DESC_TEXT), dbc.Row( className="flex-grow-1", children=[ diff --git a/app/components/filter.py b/app/components/filter.py index 3f44c82..0d48614 100644 --- a/app/components/filter.py +++ b/app/components/filter.py @@ -7,6 +7,7 @@ import plotly.express as px import plotly.graph_objects as go +from app.classes import JUMBOTRON, CONTAINER, DESC_TEXT, HEADER from app.utils import filter_metrics_json @@ -41,11 +42,11 @@ def create_metric_sliders(self): div = html.Div( children=[ - dbc.Row( - className="mb-2", + html.Div( + className="d-flex flex-row", children=[ - dbc.Col(html.Label(self.metrics[i]), width=4), - dbc.Col(sliders[i], width=8) + html.Label(self.metrics[i], className="w-25"), # w-25 and flex-grow-1 ensures they line up + html.Div(sliders[i], className="flex-grow-1") ] ) for i in range(len(self.metrics)) @@ -128,50 +129,37 @@ def create_filter_div(self): TODO: Currently the slider tooltips show even while loading which is a bit of an eyesore. """ div = html.Div( - className="p-3 bg-white rounded-5 mx-auto w-75 mb-3", + className=JUMBOTRON, children=[ dbc.Container( fluid=True, - className="py-3", + className=CONTAINER, children=[ - html.H2("Filter Policies by Desired Behavior", className="text-center mb-5"), - html.P("One hundred AI models are trained to create different energy policies that make trade \ + html.H2("Filter Policies by Desired Behavior", className=HEADER), + html.P("One hundred AI models are trained to create energy policies that make different trade \ offs in metrics. Use the sliders below to filter the AI generated policies \ that produce desired behavior resulting from their automatically generated energy \ policy. See how this filtering affects the behavior of the policies in the below \ sections.", - className="text-center"), - dcc.Loading( - type="circle", - target_components={"metrics-store": "*"}, - children=[ - self.create_metric_sliders(), - dcc.Store(id="metrics-store") - ], + className=DESC_TEXT), + html.Div( + dcc.Loading( + type="circle", + target_components={"metrics-store": "*"}, + children=[ + self.create_metric_sliders(), + dcc.Store(id="metrics-store") + ], + ), + className="w-100 mb-5" ), html.Div( - className="d-flex flex-column align-items-center", - children=[ - dbc.Button( - "Toggle Detailed Select", - id="parcoords-collapse-button", - className="mb-3", - color="secondary", - outline=True, - n_clicks=0 - ), - dbc.Collapse( - children=[ - dbc.Card( - dcc.Graph(id="parcoords-figure"), - color="secondary" - ) - ], - id="parcoords-collapse", - className="bg-gray rounded-5", - is_open=False - ) - ] + dbc.Accordion( + dbc.AccordionItem(dcc.Graph(id="parcoords-figure"), title="View Parallel Coordinates"), + start_collapsed=True, + flush=True + ), + className="w-100" ) ] ) @@ -211,18 +199,18 @@ def update_filter_sliders(metrics_jsonl: list[dict[str, list]]) -> list: return total_output - @app.callback( - Output("parcoords-collapse", "is_open"), - Input("parcoords-collapse-button", "n_clicks"), - State("parcoords-collapse", "is_open") - ) - def toggle_parcoords_collapse(n, is_open): - """ - Toggles collapse. From dbc documentation. - """ - if n: - return not is_open - return is_open + # @app.callback( + # Output("parcoords-collapse", "is_open"), + # Input("parcoords-collapse-button", "n_clicks"), + # State("parcoords-collapse", "is_open") + # ) + # def toggle_parcoords_collapse(n, is_open): + # """ + # Toggles collapse. From dbc documentation. + # """ + # if n: + # return not is_open + # return is_open @app.callback( Output("parcoords-figure", "figure"), diff --git a/app/components/outcome.py b/app/components/outcome.py index 8037e88..098816b 100644 --- a/app/components/outcome.py +++ b/app/components/outcome.py @@ -7,6 +7,7 @@ import plotly.express as px import plotly.graph_objects as go +from app.classes import JUMBOTRON, CONTAINER, DESC_TEXT, HEADER from app.utils import EvolutionHandler, filter_metrics_json @@ -100,32 +101,31 @@ def create_outcomes_div(self): is updated. Otherwise, we have individual loads for each graph. """ div = html.Div( - className="p-3 bg-white rounded-5 mx-auto w-75 mb-3", + className=JUMBOTRON, children=[ dbc.Container( fluid=True, - className="py-3", + className=CONTAINER, children=[ - dbc.Row(html.H2("Outcomes for Selected Policies", className="text-center mb-5")), - dbc.Row( + html.H2("Outcomes for Selected Policies", className=HEADER), + html.Div( + className="d-flex flex-row w-100", children=[ - dbc.Col( - children=[ - dcc.Dropdown( - id="outcome-dropdown-1", - options=self.plot_outcomes, - value=self.plot_outcomes[0] - ) - ] + html.Div( + dcc.Dropdown( + id="outcome-dropdown-1", + options=self.plot_outcomes, + value=self.plot_outcomes[0] + ), + className="flex-fill" ), - dbc.Col( - children=[ - dcc.Dropdown( - id="outcome-dropdown-2", - options=self.plot_outcomes, - value=self.plot_outcomes[1] - ) - ] + html.Div( + dcc.Dropdown( + id="outcome-dropdown-2", + options=self.plot_outcomes, + value=self.plot_outcomes[1] + ), + className="flex-fill" ) ] ), @@ -135,16 +135,11 @@ def create_outcomes_div(self): children=[ dcc.Store(id="context-actions-store"), dcc.Store(id="outcomes-store"), - dbc.Row( + html.Div( + className="d-flex flex-row w-100", children=[ - dbc.Col( - dcc.Graph(id="outcome-graph-1"), - width=6 - ), - dbc.Col( - dcc.Graph(id="outcome-graph-2"), - width=6 - ) + dcc.Graph(id="outcome-graph-1"), + dcc.Graph(id="outcome-graph-2") ] ) ] From ed244877ca52e55f23e3b8b4d1b7001841116cd2 Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Wed, 11 Sep 2024 09:28:00 -0700 Subject: [PATCH 06/11] Fixed outcome formatting by using grid within flex --- app/components/outcome.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/app/components/outcome.py b/app/components/outcome.py index 098816b..ae28840 100644 --- a/app/components/outcome.py +++ b/app/components/outcome.py @@ -135,11 +135,11 @@ def create_outcomes_div(self): children=[ dcc.Store(id="context-actions-store"), dcc.Store(id="outcomes-store"), - html.Div( - className="d-flex flex-row w-100", + dbc.Row( + className="g-0", children=[ - dcc.Graph(id="outcome-graph-1"), - dcc.Graph(id="outcome-graph-2") + dbc.Col(dcc.Graph(id="outcome-graph-1"), width=6), + dbc.Col(dcc.Graph(id="outcome-graph-2"), width=6) ] ) ] From 7e5ee1708109f5d56fa8dc3bb8e77ec15581e184 Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Wed, 11 Sep 2024 10:39:17 -0700 Subject: [PATCH 07/11] finished bootstrap-ifying the page --- app/components/link.py | 12 ++++++++---- app/components/outcome.py | 2 +- app/components/references.py | 8 +++++--- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/app/components/link.py b/app/components/link.py index b5d4724..1e220c9 100644 --- a/app/components/link.py +++ b/app/components/link.py @@ -6,6 +6,7 @@ import pandas as pd import plotly.graph_objects as go +from app.classes import JUMBOTRON, CONTAINER, DESC_TEXT, HEADER from enroadspy.generate_url import actions_to_url @@ -95,17 +96,17 @@ def create_link_div(self): TODO: Make link unclickable while outcomes are loading. """ div = html.Div( - className="p-3 bg-white rounded-5 mx-auto w-75 mb-3", + className=JUMBOTRON, children=[ dbc.Container( fluid=True, - className="py-3 justify-content-center", + className=CONTAINER, children=[ html.H2("View Policy Energy Sources and Explore Policy in En-ROADS", - className="text-center mb-5"), + className=HEADER), html.P("Select a policy to preview its resulting distribution of energy sources over time. \ Then click on the link to explore and fine-tune the policy in En-ROADS.", - className="text-center w-70 mb-2 mx-auto"), + className=DESC_TEXT), dcc.Loading( type="circle", children=[ @@ -145,6 +146,9 @@ def create_link_div(self): return div def register_callbacks(self, app): + """ + Registers callbacks for the links component. + """ @app.callback( Output("energy-policy-graph", "figure"), Input("energy-policy-store", "data"), diff --git a/app/components/outcome.py b/app/components/outcome.py index ae28840..2b18a35 100644 --- a/app/components/outcome.py +++ b/app/components/outcome.py @@ -7,7 +7,7 @@ import plotly.express as px import plotly.graph_objects as go -from app.classes import JUMBOTRON, CONTAINER, DESC_TEXT, HEADER +from app.classes import JUMBOTRON, CONTAINER, HEADER from app.utils import EvolutionHandler, filter_metrics_json diff --git a/app/components/references.py b/app/components/references.py index fb9aa1f..0a3ed5c 100644 --- a/app/components/references.py +++ b/app/components/references.py @@ -4,6 +4,8 @@ from dash import html import dash_bootstrap_components as dbc +from app.classes import JUMBOTRON, CONTAINER, HEADER + class ReferencesComponent(): """ @@ -14,13 +16,13 @@ def create_references_div(self): Creates div displaying references """ div = html.Div( - className="p-3 bg-white rounded-5 mx-auto w-75 mb-3", + className=JUMBOTRON, children=[ dbc.Container( fluid=True, - className="py-3", + className=CONTAINER[:-18], # Left-align our references children=[ - html.H2("References", className="text-center mb-2"), + html.H2("References", className=HEADER), html.P([ "For more info about Project Resilience, visit the ", html.A("United Nations ITU Page", From 673b3683043184f4449b59ebd65d5f8d314f6cc7 Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Wed, 11 Sep 2024 10:40:54 -0700 Subject: [PATCH 08/11] Spread out sliders for filter a bit --- app/components/filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/components/filter.py b/app/components/filter.py index 0d48614..40dac0d 100644 --- a/app/components/filter.py +++ b/app/components/filter.py @@ -43,7 +43,7 @@ def create_metric_sliders(self): div = html.Div( children=[ html.Div( - className="d-flex flex-row", + className="d-flex flex-row mb-2", children=[ html.Label(self.metrics[i], className="w-25"), # w-25 and flex-grow-1 ensures they line up html.Div(sliders[i], className="flex-grow-1") From bb149668cb12fccf7dd7ecb3187bbe95cbef6bd3 Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Wed, 11 Sep 2024 10:53:32 -0700 Subject: [PATCH 09/11] Miscellaneous cleanup and linting --- app/app.py | 6 ------ app/components/filter.py | 21 +++------------------ app/components/link.py | 3 ++- app/components/outcome.py | 6 +++--- app/components/parallel.py | 1 + 5 files changed, 9 insertions(+), 28 deletions(-) diff --git a/app/app.py b/app/app.py index 72cea2d..5ca2775 100644 --- a/app/app.py +++ b/app/app.py @@ -8,7 +8,6 @@ from app.components.intro import IntroComponent from app.components.context import ContextComponent from app.components.filter import FilterComponent -# from app.components.parallel import ParallelComponent from app.components.outcome import OutcomeComponent from app.components.link import LinkComponent from app.components.references import ReferencesComponent @@ -22,9 +21,6 @@ intro_component = IntroComponent() context_component = ContextComponent() filter_component = FilterComponent(metrics) -# parallel_component = ParallelComponent(evolution_handler.load_initial_metrics_df(), -# sample_idxs, -# evolution_handler.outcomes) outcome_component = OutcomeComponent(evolution_handler) link_component = LinkComponent(sample_idxs) references_component = ReferencesComponent() @@ -36,7 +32,6 @@ context_component.register_callbacks(app) filter_component.register_callbacks(app) -# parallel_component.register_callbacks(app) outcome_component.register_callbacks(app) link_component.register_callbacks(app) @@ -46,7 +41,6 @@ intro_component.create_intro_div(), context_component.create_context_div(), filter_component.create_filter_div(), - # parallel_component.create_parallel_div(), outcome_component.create_outcomes_div(), link_component.create_link_div(), references_component.create_references_div() diff --git a/app/components/filter.py b/app/components/filter.py index 40dac0d..b16621f 100644 --- a/app/components/filter.py +++ b/app/components/filter.py @@ -53,16 +53,15 @@ def create_metric_sliders(self): ] ) return div - + def plot_parallel_coordinates_line(self, metrics_json: dict[str, list], metric_ranges: list[tuple[float, float]]) -> go.Figure: """ - NOTE: This is legacy code that may be brought back in later for a user toggle. Plots a parallel coordinates plot of the prescriptor metrics. - Starts by plotting "other" if selected so that it's the bottom of the z axis. + Starts by plotting "other" so that it's the bottom of the z axis. Then plots selected candidates in color. - Finally plots the baseline on top if selected. + Finally plots the baseline on top. """ fig = go.Figure() @@ -157,7 +156,6 @@ def create_filter_div(self): dbc.Accordion( dbc.AccordionItem(dcc.Graph(id="parcoords-figure"), title="View Parallel Coordinates"), start_collapsed=True, - flush=True ), className="w-100" ) @@ -199,19 +197,6 @@ def update_filter_sliders(metrics_jsonl: list[dict[str, list]]) -> list: return total_output - # @app.callback( - # Output("parcoords-collapse", "is_open"), - # Input("parcoords-collapse-button", "n_clicks"), - # State("parcoords-collapse", "is_open") - # ) - # def toggle_parcoords_collapse(n, is_open): - # """ - # Toggles collapse. From dbc documentation. - # """ - # if n: - # return not is_open - # return is_open - @app.callback( Output("parcoords-figure", "figure"), State("metrics-store", "data"), diff --git a/app/components/link.py b/app/components/link.py index 1e220c9..f97fbb4 100644 --- a/app/components/link.py +++ b/app/components/link.py @@ -109,6 +109,7 @@ def create_link_div(self): className=DESC_TEXT), dcc.Loading( type="circle", + target_components={"energy-policy-store": "*"}, children=[ dcc.Store(id="energy-policy-store"), dcc.Graph(id="energy-policy-graph", className="mb-2") @@ -157,7 +158,7 @@ def register_callbacks(self, app): def update_energy_policy_graph(energy_policy_jsonl: list[dict[str, list]], cand_idx) -> go.Figure: if cand_idx is not None: return self.plot_energy_policy(energy_policy_jsonl, cand_idx) - + # If we have no cand id just return a blank figure asking the user to select a candidate. fig = go.Figure() fig.update_layout( diff --git a/app/components/outcome.py b/app/components/outcome.py index 2b18a35..c468a3b 100644 --- a/app/components/outcome.py +++ b/app/components/outcome.py @@ -133,11 +133,11 @@ def create_outcomes_div(self): target_components={"context-actions-store": "*", "outcomes-store": "*"}, type="circle", children=[ - dcc.Store(id="context-actions-store"), - dcc.Store(id="outcomes-store"), dbc.Row( className="g-0", children=[ + dcc.Store(id="context-actions-store"), + dcc.Store(id="outcomes-store"), dbc.Col(dcc.Graph(id="outcome-graph-1"), width=6), dbc.Col(dcc.Graph(id="outcome-graph-2"), width=6) ] @@ -211,7 +211,7 @@ def update_outcomes_plots(metrics_json, outcome1, outcome2, outcomes_jsonl, *met fig1 = self.plot_outcome_over_time(outcome1, outcomes_jsonl, cand_idxs) fig2 = self.plot_outcome_over_time(outcome2, outcomes_jsonl, cand_idxs) return fig1, fig2 - + @app.callback( Output("cand-link-select", "options"), State("metrics-store", "data"), diff --git a/app/components/parallel.py b/app/components/parallel.py index c65f8b9..058cb97 100644 --- a/app/components/parallel.py +++ b/app/components/parallel.py @@ -1,4 +1,5 @@ """ +NOTE: This is legacy code that may be brought back in later. File containing component in charge of visualizing the candidates' metrics. """ from dash import html, dcc, Input, Output From 2e308e638660963a2dd60256b2cb66750d5b2d24 Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Wed, 11 Sep 2024 10:53:49 -0700 Subject: [PATCH 10/11] Updated lint command to not get bad files --- .github/workflows/enroads.yml | 2 +- .pylintrc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/enroads.yml b/.github/workflows/enroads.yml index 7335a97..20c1769 100644 --- a/.github/workflows/enroads.yml +++ b/.github/workflows/enroads.yml @@ -29,7 +29,7 @@ jobs: ENROADS_PASSWORD: ${{ secrets.ENROADS_PASSWORD }} run: python -m enroadspy.download_sdk - name: Lint with PyLint - run: pylint ./* + run: pylint . - name: Lint with Flake8 run: flake8 - name: Run unit tests diff --git a/.pylintrc b/.pylintrc index f44050d..9904ae5 100644 --- a/.pylintrc +++ b/.pylintrc @@ -10,4 +10,4 @@ suggestion-mode=yes good-names=X,F,X0 -fail-under=9.6 \ No newline at end of file +fail-under=9.8 \ No newline at end of file From 9300b3aca08fc330679f92d3600d805083b96978 Mon Sep 17 00:00:00 2001 From: Daniel Young Date: Wed, 11 Sep 2024 11:13:05 -0700 Subject: [PATCH 11/11] Reworded filter description text --- app/components/filter.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/app/components/filter.py b/app/components/filter.py index b16621f..25835c8 100644 --- a/app/components/filter.py +++ b/app/components/filter.py @@ -137,9 +137,7 @@ def create_filter_div(self): html.H2("Filter Policies by Desired Behavior", className=HEADER), html.P("One hundred AI models are trained to create energy policies that make different trade \ offs in metrics. Use the sliders below to filter the AI generated policies \ - that produce desired behavior resulting from their automatically generated energy \ - policy. See how this filtering affects the behavior of the policies in the below \ - sections.", + that produce desired behavior. See the results of the filtering in the below sections.", className=DESC_TEXT), html.Div( dcc.Loading(