Skip to content

Commit

Permalink
Merge pull request #15 from danyoungday/misc-cleanup
Browse files Browse the repository at this point in the history
Misc cleanup
  • Loading branch information
danyoungday authored Sep 20, 2024
2 parents c984d6d + 30af11e commit e137a6b
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 87 deletions.
16 changes: 6 additions & 10 deletions app/components/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def create_context_div(self):
# TODO: Make the box big enough to fit the text
html.Div(
id="ssp-desc",
children=[self.construct_ssp_desc(0)],
children=[html.H4("Select a Scenario")],
className="flex-grow-1 overflow-auto border rounded-3 p-2",
style={"height": "275px"}
)
Expand All @@ -164,7 +164,7 @@ def create_context_div(self):
dbc.Button(
"AI Generate Policies for Scenario",
id="presc-button",
className="me-1",
className="me-1 mb-2",
n_clicks=0
)
]
Expand Down Expand Up @@ -194,19 +194,15 @@ def register_callbacks(self, app):
"""
@app.callback(
[Output(f"context-slider-{i}", "value") for i in range(4)],
Input("context-scatter", "clickData")
Input("context-scatter", "clickData"),
prevent_initial_call=True
)
def click_context(click_data):
"""
Updates context sliders when a context point is clicked.
TODO: Sometimes this function lags, not sure why.
"""
if click_data:
# TODO: This assumes the SSPs in the ssps.csv file are in order which they are
scenario = int(click_data["points"][0]["pointNumber"])
else:
scenario = 0

# TODO: This assumes the SSPs in the ssps.csv file are in order which they are
scenario = int(click_data["points"][0]["pointNumber"])
scenario = f"SSP{scenario+1}-Baseline"
row = self.context_df[self.context_df["scenario"] == scenario].iloc[0]
return [row[self.context_cols[i]] for i in range(4)]
Expand Down
34 changes: 27 additions & 7 deletions app/components/filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Component in charge of filtering out prescriptors by metric.
"""
import json

from dash import html, dcc, Input, Output, State
import dash_bootstrap_components as dbc
import pandas as pd
Expand All @@ -22,9 +24,14 @@ def __init__(self, metrics: list[str]):
self.metric_ids = [metric.replace(" ", "-").replace(".", "_") for metric in self.metrics]
self.updated_params = ["min", "max", "value", "marks"]

with open("app/units.json", "r", encoding="utf-8") as f:
self.units = json.load(f)

def create_metric_sliders(self):
"""
Creates initial metric sliders and lines them up with their labels.
TODO: We need to stop hard-coding their names and adjustments.
TODO: Add a tooltip to the sliders to show their units.
"""
sliders = []
for metric in self.metrics:
Expand All @@ -36,16 +43,22 @@ def create_metric_sliders(self):
value=[0, 1],
marks={0: f"{0:.2f}", 1: f"{1:.2f}"},
tooltip={"placement": "bottom", "always_visible": True},
allowCross=False
allowCross=False,
disabled=True
)
sliders.append(slider)

names_map = dict(zip(self.metrics, ["Temperature change from 1850",
"Highest cost of energy",
"Government spending",
"Reduction in energy demand"]))
# w-25 and flex-grow-1 ensures they line up
div = html.Div(
children=[
html.Div(
className="d-flex flex-row mb-2",
children=[
html.Label(self.metrics[i], className="w-25"), # w-25 and flex-grow-1 ensures they line up
html.Label(f"{names_map[self.metrics[i]]} ({self.units[self.metrics[i]]})", className="w-25"),
html.Div(sliders[i], className="flex-grow-1")
]
)
Expand Down Expand Up @@ -161,7 +174,7 @@ def create_filter_div(self):
className="me-1",
style={"width": "200px"} # TODO: We hard-code the width here because of text size
),
dbc.Button("Reset Filters", id="reset-button")
dbc.Button("Reset Filters", id="reset-button", disabled=True)
]
),
html.Div(
Expand All @@ -183,14 +196,18 @@ def register_callbacks(self, app):
"""
@app.callback(
[Output(f"{metric_id}-slider", param) for metric_id in self.metric_ids for param in self.updated_params],
[Output(f"{metric_id}-slider", "disabled") for metric_id in self.metric_ids],
Output("reset-button", "disabled"),
Input("metrics-store", "data"),
Input("reset-button", "n_clicks")
Input("reset-button", "n_clicks"),
prevent_initial_call=True
)
def update_filter_sliders(metrics_jsonl: list[dict[str, list]], _) -> list:
"""
Update the filter slider min/max/value/marks based on the incoming metrics data. The output of this function
is a list of the updated parameters for each slider concatenated.
This also happens whenever we click the reset button.
The reset button starts disabled but once the sliders are updated for the first time it becomes enabled.
"""
metrics_df = pd.DataFrame(metrics_jsonl)
total_output = []
Expand All @@ -208,13 +225,15 @@ def update_filter_sliders(metrics_jsonl: list[dict[str, list]], _) -> list:
{min_val_rounded: f"{min_val_rounded:.2f}", max_val_rounded: f"{max_val_rounded:.2f}"}
]
total_output.extend(metric_output)

total_output.extend([False] * len(self.metric_ids)) # Enable all sliders
total_output.append(False) # Enable reset button
return total_output

@app.callback(
Output("parcoords-figure", "figure"),
State("metrics-store", "data"),
[Input(f"{metric_id}-slider", "value") for metric_id in self.metric_ids]
[Input(f"{metric_id}-slider", "value") for metric_id in self.metric_ids],
prevent_initial_call=True
)
def filter_parcoords_figure(metrics_json: dict[str, list], *metric_ranges) -> go.Figure:
"""
Expand All @@ -225,7 +244,8 @@ def filter_parcoords_figure(metrics_json: dict[str, list], *metric_ranges) -> go
@app.callback(
Output("cand-counter", "children"),
State("metrics-store", "data"),
[Input(f"{metric_id}-slider", "value") for metric_id in self.metric_ids]
[Input(f"{metric_id}-slider", "value") for metric_id in self.metric_ids],
prevent_initial_call=True
)
def count_selected_cands(metrics_json: dict[str, list], *metric_ranges) -> str:
"""
Expand Down
35 changes: 28 additions & 7 deletions app/components/intro.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,34 @@ def create_intro_div(self):
html.H2("Decision Making for Climate Change", className="display-4 w-50 mx-auto text-center mb-3")
),
dbc.Row(
html.P("Immediate action is required to combat climate change. \
The technology behind Cognizant NeuroAI brings automatic decision-making to the En-ROADS \
platform, a powerful climate change simulator. A decision-maker can be ready for any \
scenario: choosing an automatically generated policy that suits their needs best, with the \
ability to manually modify the policy and see its results. This tool is brought together \
under Project Resilience, a United Nations initiative to use AI for good.",
className="lead w-50 mx-auto text-center")
html.P(
[
"Immediate action is required to combat climate change. The technology behind ",
html.A(
"Cognizant NeuroAI",
href="https://www.cognizant.com/us/en/services/ai/ai-lab",
style={"color": "black"}
),
" brings automatic decision-making to the ",
html.A(
"En-ROADS platform",
href="https://www.climateinteractive.org/en-roads/",
style={"color": "black"}
),
", a powerful climate change simulator. A decision-maker can be ready for any \
scenario: choosing an automatically generated policy that suits their needs best, with the \
ability to manually modify the policy and see its results. This tool is brought together \
under ",
html.A(
"Project Resilience",
href="https://www.itu.int/en/ITU-T/extcoop/ai-data-commons/\
Pages/project-resilience.aspx",
style={"color": "black"}
),
", a United Nations initiative to use AI for good."
],
className="lead w-50 mx-auto text-center"
)
),
dbc.Row(
style={"height": "60vh"}
Expand Down
18 changes: 12 additions & 6 deletions app/components/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,18 @@ def create_link_div(self):
Then click on the link to explore and fine-tune the policy in En-ROADS.",
className=DESC_TEXT),
html.Div(
dcc.Dropdown(
id="cand-link-select",
options=[],
placeholder="Select a policy",
),
className="w-25 flex-grow-1"
className="d-flex flex-row w-25 justify-content-center",
children=[
html.Label("Policy: ", className="pt-1 me-1"),
html.Div(
dcc.Dropdown(
id="cand-link-select",
options=[],
placeholder="Select a policy",
),
className="flex-grow-1"
)
]
),
dcc.Loading(
type="circle",
Expand Down
85 changes: 65 additions & 20 deletions app/components/outcome.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""
OutcomeComponent class for the outcome section of the app.
"""
from dash import Input, Output, State, html, dcc
import json

from dash import Input, Output, State, html, dcc, MATCH
import dash_bootstrap_components as dbc
import pandas as pd
import plotly.express as px
Expand Down Expand Up @@ -30,6 +32,9 @@ def __init__(self, evolution_handler: EvolutionHandler):

self.metric_ids = [metric.replace(" ", "-").replace(".", "_") for metric in self.evolution_handler.outcomes]

with open("app/units.json", "r", encoding="utf-8") as f:
self.units = json.load(f)

def plot_outcome_over_time(self, outcome: str, outcomes_jsonl: list[list[dict[str, float]]], cand_idxs: list[int]):
"""
Plots all the candidates' prescribed actions' outcomes for a given context.
Expand Down Expand Up @@ -120,7 +125,8 @@ def plot_outcome_over_time(self, outcome: str, outcomes_jsonl: list[list[dict[st
"text": f"{outcome} Over Time",
"x": 0.5,
"xanchor": "center"},
yaxis_range=[y_min, y_max]
yaxis_range=[y_min, y_max],
yaxis_title=outcome + f" ({self.units[outcome]})"
)
return fig

Expand All @@ -142,33 +148,71 @@ def create_outcomes_div(self):
children=[
html.Div(
dcc.Dropdown(
id="outcome-dropdown-1",
id={"type": "outcome-dropdown", "index": 0},
options=self.plot_outcomes,
value=self.plot_outcomes[0]
value=self.plot_outcomes[0],
disabled=True
),
className="flex-fill"
),
html.Div(
dcc.Dropdown(
id="outcome-dropdown-2",
id={"type": "outcome-dropdown", "index": 1},
options=self.plot_outcomes,
value=self.plot_outcomes[1]
value=self.plot_outcomes[1],
disabled=True
),
className="flex-fill"
)
]
),
dcc.Loading(
target_components={"context-actions-store": "*", "outcomes-store": "*"},
target_components={"context-actions-store": "*"},
type="circle",
children=[
dbc.Row(
className="g-0",
children=[
dcc.Store(id="context-actions-store"),
dbc.Col(dcc.Graph(id={"type": "outcome-graph", "index": 0}), width=6),
dbc.Col(dcc.Graph(id={"type": "outcome-graph", "index": 1}), width=6)
]
)
]
),
html.Div(
className="d-flex flex-row w-100",
children=[
html.Div(
dcc.Dropdown(
id={"type": "outcome-dropdown", "index": 2},
options=self.plot_outcomes,
value=self.plot_outcomes[2],
disabled=True
),
className="flex-fill"
),
html.Div(
dcc.Dropdown(
id={"type": "outcome-dropdown", "index": 3},
options=self.plot_outcomes,
value=self.plot_outcomes[3],
disabled=True
),
className="flex-fill"
)
]
),
dcc.Loading(
target_components={"outcomes-store": "*"},
type="circle",
children=[
dbc.Row(
className="g-0",
children=[
dcc.Store(id="outcomes-store"),
dbc.Col(dcc.Graph(id="outcome-graph-1"), width=6),
dbc.Col(dcc.Graph(id="outcome-graph-2"), width=6)
dbc.Col(dcc.Graph(id={"type": "outcome-graph", "index": 2}), width=6),
dbc.Col(dcc.Graph(id={"type": "outcome-graph", "index": 3}), width=6)
]
)
]
Expand All @@ -190,7 +234,8 @@ def register_callbacks(self, app):
Output("metrics-store", "data"),
Output("energy-policy-store", "data"),
Input("presc-button", "n_clicks"),
[State(f"context-slider-{i}", "value") for i in range(4)]
[State(f"context-slider-{i}", "value") for i in range(4)],
prevent_initial_call=True
)
def update_results_stores(_, *context_values):
"""
Expand Down Expand Up @@ -223,29 +268,29 @@ def update_results_stores(_, *context_values):
return context_actions_dicts, outcomes_jsonl, metrics_json, energy_policy_jsonl

@app.callback(
Output("outcome-graph-1", "figure"),
Output("outcome-graph-2", "figure"),
Output({"type": "outcome-graph", "index": MATCH}, "figure"),
Output({"type": "outcome-dropdown", "index": MATCH}, "disabled"),
State("metrics-store", "data"),
Input("outcome-dropdown-1", "value"),
Input("outcome-dropdown-2", "value"),
Input({"type": "outcome-dropdown", "index": MATCH}, "value"),
Input("outcomes-store", "data"),
[Input(f"{metric_id}-slider", "value") for metric_id in self.metric_ids],
prevent_initial_call=True
)
def update_outcomes_plots(metrics_json, outcome1, outcome2, outcomes_jsonl, *metric_ranges):
def update_outcomes_plots(metrics_json, outcome, outcomes_jsonl, *metric_ranges):
"""
Updates outcome plot when specific outcome is selected or context scatter point is clicked.
We also un-disable the dropdowns when the user selects a context.
"""
metrics_df = filter_metrics_json(metrics_json, metric_ranges)
cand_idxs = list(metrics_df.index)[:-1] # So we don't include the baseline

fig1 = self.plot_outcome_over_time(outcome1, outcomes_jsonl, cand_idxs)
fig2 = self.plot_outcome_over_time(outcome2, outcomes_jsonl, cand_idxs)
return fig1, fig2
fig = self.plot_outcome_over_time(outcome, outcomes_jsonl, cand_idxs)
return fig, False

@app.callback(
Output("cand-link-select", "options"),
State("metrics-store", "data"),
[Input(f"{metric_id}-slider", "value") for metric_id in self.metric_ids]
[Input(f"{metric_id}-slider", "value") for metric_id in self.metric_ids],
prevent_initial_call=True
)
def update_cand_link_select(metrics_json: dict[str, list],
*metric_ranges: list[tuple[float, float]]) -> list[int]:
Expand Down
Loading

0 comments on commit e137a6b

Please sign in to comment.