Skip to content

Commit

Permalink
Merge pull request #10 from danyoungday/presc-select
Browse files Browse the repository at this point in the history
Prescriptor select
  • Loading branch information
danyoungday authored Sep 11, 2024
2 parents 2538948 + 9300b3a commit 521a8f2
Show file tree
Hide file tree
Showing 12 changed files with 464 additions and 156 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/enroads.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
ENROADS_PASSWORD: ${{ secrets.ENROADS_PASSWORD }}
run: python -m enroadspy.download_sdk
- name: Lint with PyLint
run: pylint ./*
run: pylint .
- name: Lint with Flake8
run: flake8
- name: Run unit tests
Expand Down
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ suggestion-mode=yes

good-names=X,F,X0

fail-under=9.6
fail-under=9.8
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# Decision Making for Climate Change

Immediate action is required to combat climate change. The technology behind Cognizant NeuroAI brings automatic decision-making to the En-ROADS platform, a powerful climate change simulator. A decision-maker can be ready for any scenario: choosing an automatically generated policy that suits their needs best, with the ability to manually modify the policy and see its results. This tool is brought together under Project Resilience, a United Nations initiative to use AI for good.
Immediate action is required to combat climate change. The technology behind [Cognizant NeuroAI](https://evolution.ml/) brings automatic decision-making to the En-ROADS platform, a powerful climate change simulator. A decision-maker can be ready for any scenario: choosing an automatically generated policy that suits their needs best, with the ability to manually modify the policy and see its results. This tool is brought together under Project Resilience, a United Nations initiative to use AI for good.

## En-ROADS Wrapper

En-ROADS is a climate change simulator developed by Climate Interactive. We have created a wrapper around the SDK to make it simple to use in a Python application which can be found in `enroadspy`. See `enroads_runner.py` for the main class that runs the SDK. The SDK is not included in this repository and must be requested from Climate Interactive.
En-ROADS is a climate change simulator developed by Climate Interactive. We have created a wrapper around the SDK to make it simple to use in a Python application which can be found in `enroadspy`. See `enroads_runner.py` for the main class that runs the SDK. The SDK is not included in this repository and must be requested from [Climate Interactive](https://www.climateinteractive.org/).

The input data format is a crazy long JSON object which I copied out of the source code, pasted into `inputSpecs.py`, and parsed into `inputSpecs.jsonl`. This format is used by the rest of the code.
The input data format is a crazy long JSON object which I copied out of the source code, pasted into `inputSpecs.py`, and parsed into `inputSpecs.jsonl`. This format is used by the rest of the repository.

### Installation
Download the en-roads zip file, place it in the root folder of the repository, and unzip it.
Run `pip install -r requirements.txt` to install the required packages. Then run `python -m enroadspy.download_sdk` to download the SDK. In order to download the SDK environment variables must be set which can be requested online from Climate Interactive.

## Evolution

Expand Down
15 changes: 7 additions & 8 deletions app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,21 @@

from app.components.intro import IntroComponent
from app.components.context import ContextComponent
from app.components.filter import FilterComponent
from app.components.outcome import OutcomeComponent
from app.components.parallel import ParallelComponent
from app.components.link import LinkComponent
from app.components.references import ReferencesComponent
from app.utils import EvolutionHandler

evolution_handler = EvolutionHandler()
metrics = evolution_handler.outcomes.keys()
# The candidates are sorted by rank then distance so the 'best' ones are the first 10
sample_idxs = list(range(10))

intro_component = IntroComponent()
context_component = ContextComponent()
parallel_component = ParallelComponent(evolution_handler.load_initial_metrics_df(),
sample_idxs,
evolution_handler.outcomes)
outcome_component = OutcomeComponent(evolution_handler, sample_idxs)
filter_component = FilterComponent(metrics)
outcome_component = OutcomeComponent(evolution_handler)
link_component = LinkComponent(sample_idxs)
references_component = ReferencesComponent()

Expand All @@ -32,7 +31,7 @@
app.title = "Climate Change Decision Making"

context_component.register_callbacks(app)
parallel_component.register_callbacks(app)
filter_component.register_callbacks(app)
outcome_component.register_callbacks(app)
link_component.register_callbacks(app)

Expand All @@ -41,7 +40,7 @@
children=[
intro_component.create_intro_div(),
context_component.create_context_div(),
parallel_component.create_parallel_div(),
filter_component.create_filter_div(),
outcome_component.create_outcomes_div(),
link_component.create_link_div(),
references_component.create_references_div()
Expand All @@ -50,4 +49,4 @@

# Run the app
if __name__ == '__main__':
app.run_server(host='0.0.0.0', debug=False, port=4057, use_reloader=False, threaded=True)
app.run_server(host='0.0.0.0', debug=False, port=4057, use_reloader=True, threaded=True)
10 changes: 10 additions & 0 deletions app/classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
Hard-coded combinations of classes using bootstrap akin to custom css classes.
"""
DESC_TEXT = "mb-5 w-75 text-center"

JUMBOTRON = "p-3 bg-white rounded-5 mx-auto w-75 mb-3"

CONTAINER = "py-3 d-flex flex-column h-100 align-items-center"

HEADER = "text-center mb-5"
20 changes: 9 additions & 11 deletions app/components/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import plotly.express as px
import plotly.graph_objects as go

from app.classes import JUMBOTRON, CONTAINER, DESC_TEXT, HEADER
from enroadspy import load_input_specs


Expand Down Expand Up @@ -123,21 +124,18 @@ def create_context_div(self):
)

div = html.Div(
className="p-3 bg-white rounded-5 mx-auto w-75 mb-3",
className=JUMBOTRON,
children=[
dbc.Container(
fluid=True,
className="py-3 d-flex flex-column h-100",
className=CONTAINER,
children=[
dbc.Row(html.H2("Select a Context Scenario to Optimize For", className="text-center mb-5")),
dbc.Row(
className="mb-2 w-70 text-center mx-auto",
children=[html.P("According to the AR6 climate report: 'The five Shared Socioeconomic \
Pathways were designed to span a range of challenges to climate change \
mitigation and adaptation.' Select one of these scenarios by clicking it \
in the scatter plot below. If desired, manually modify the scenario \
with the sliders.")]
),
html.H2("Select a Context Scenario to Optimize For", className=HEADER),
html.P("According to the AR6 climate report: 'The five Shared Socioeconomic \
Pathways were designed to span a range of challenges to climate change \
mitigation and adaptation'. Select one of these scenarios by clicking it \
in the scatter plot below. If desired, manually modify the scenario \
with the sliders.", className=DESC_TEXT),
dbc.Row(
className="flex-grow-1",
children=[
Expand Down
207 changes: 207 additions & 0 deletions app/components/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""
Component in charge of filtering out prescriptors by metric.
"""
from dash import html, dcc, Input, Output, State
import dash_bootstrap_components as dbc
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from app.classes import JUMBOTRON, CONTAINER, DESC_TEXT, HEADER
from app.utils import filter_metrics_json


class FilterComponent:
"""
Component in charge of filtering out prescriptors by metric specific to each context.
The component stores the metrics to filter with as long as their corresponding HTML ids.
It also keeps track of the parameters that need to be updated for each slider.
"""
def __init__(self, metrics: list[str]):
self.metrics = list(metrics)
self.metric_ids = [metric.replace(" ", "-").replace(".", "_") for metric in self.metrics]
self.updated_params = ["min", "max", "value", "marks"]

def create_metric_sliders(self):
"""
Creates initial metric sliders and lines them up with their labels.
"""
sliders = []
for metric in self.metrics:
col_id = metric.replace(" ", "-").replace(".", "_")
slider = dcc.RangeSlider(
id=f"{col_id}-slider",
min=0,
max=1,
value=[0, 1],
marks={0: f"{0:.2f}", 1: f"{1:.2f}"},
tooltip={"placement": "bottom", "always_visible": True},
allowCross=False
)
sliders.append(slider)

div = html.Div(
children=[
html.Div(
className="d-flex flex-row mb-2",
children=[
html.Label(self.metrics[i], className="w-25"), # w-25 and flex-grow-1 ensures they line up
html.Div(sliders[i], className="flex-grow-1")
]
)
for i in range(len(self.metrics))
]
)
return div

def plot_parallel_coordinates_line(self,
metrics_json: dict[str, list],
metric_ranges: list[tuple[float, float]]) -> go.Figure:
"""
Plots a parallel coordinates plot of the prescriptor metrics.
Starts by plotting "other" so that it's the bottom of the z axis.
Then plots selected candidates in color.
Finally plots the baseline on top.
"""
fig = go.Figure()

normalized_df = filter_metrics_json(metrics_json, metric_ranges, normalize=True)

cand_idxs = list(normalized_df.index)[:-1] # Leave out the baseline
n_special_cands = min(10, len(cand_idxs))

showlegend = True
# If "other" is in the cand_idxs, plot all other candidates in lightgray
for cand_idx in cand_idxs[n_special_cands:]:
cand_metrics = normalized_df.loc[cand_idx].values
fig.add_trace(go.Scatter(
x=normalized_df.columns,
y=cand_metrics,
mode='lines',
legendgroup="other",
name="other",
line=dict(color="lightgray"),
showlegend=showlegend
))
showlegend = False

# Plot selected candidates besides baseline so it can be on top
for color_idx, cand_idx in enumerate(cand_idxs[:n_special_cands]):
cand_metrics = normalized_df.loc[cand_idx].values
fig.add_trace(go.Scatter(
x=normalized_df.columns,
y=cand_metrics,
mode='lines',
name=str(cand_idx),
line=dict(color=px.colors.qualitative.Plotly[color_idx])
))

baseline_metrics = normalized_df.iloc[-1]
fig.add_trace(go.Scatter(
x=normalized_df.columns,
y=baseline_metrics.values,
mode='lines',
name="baseline",
line=dict(color="black")
))

for i in range(len(normalized_df.columns)):
fig.add_vline(x=i, line_color="black")

full_metrics_df = pd.DataFrame(metrics_json)
normalized_full = (full_metrics_df - full_metrics_df.mean()) / (full_metrics_df.std() + 1e-10)
fig.update_layout(
yaxis_range=[normalized_full.min().min(), normalized_full.max().max()],
title={
'text': "Normalized Policy Metrics",
'x': 0.5, # Center the title
'xanchor': 'center', # Anchor it at the center
'yanchor': 'top' # Optionally keep it anchored to the top
}
)

return fig

def create_filter_div(self):
"""
Creates div showing sliders to choose the range of metric values we want the prescriptors to have.
TODO: Currently the slider tooltips show even while loading which is a bit of an eyesore.
"""
div = html.Div(
className=JUMBOTRON,
children=[
dbc.Container(
fluid=True,
className=CONTAINER,
children=[
html.H2("Filter Policies by Desired Behavior", className=HEADER),
html.P("One hundred AI models are trained to create energy policies that make different trade \
offs in metrics. Use the sliders below to filter the AI generated policies \
that produce desired behavior. See the results of the filtering in the below sections.",
className=DESC_TEXT),
html.Div(
dcc.Loading(
type="circle",
target_components={"metrics-store": "*"},
children=[
self.create_metric_sliders(),
dcc.Store(id="metrics-store")
],
),
className="w-100 mb-5"
),
html.Div(
dbc.Accordion(
dbc.AccordionItem(dcc.Graph(id="parcoords-figure"), title="View Parallel Coordinates"),
start_collapsed=True,
),
className="w-100"
)
]
)
]
)
return div

def register_callbacks(self, app):
"""
Registers callbacks related to the filter sliders.
"""
@app.callback(
[Output(f"{metric_id}-slider", param) for metric_id in self.metric_ids for param in self.updated_params],
Input("metrics-store", "data")
)
def update_filter_sliders(metrics_jsonl: list[dict[str, list]]) -> list:
"""
Update the filter slider min/max/value/marks based on the incoming metrics data. The output of this function
is a list of the updated parameters for each slider concatenated.
"""
metrics_df = pd.DataFrame(metrics_jsonl)
total_output = []
for metric in self.metrics:
metric_output = []
min_val = metrics_df[metric].min()
max_val = metrics_df[metric].max()
# We need to round down for the min value and round up for the max value
min_val_rounded = min_val // 0.01 / 100
max_val_rounded = max_val + 0.01
metric_output = [
min_val_rounded,
max_val_rounded,
[min_val_rounded, max_val_rounded],
{min_val_rounded: f"{min_val_rounded:.2f}", max_val_rounded: f"{max_val_rounded:.2f}"}
]
total_output.extend(metric_output)

return total_output

@app.callback(
Output("parcoords-figure", "figure"),
State("metrics-store", "data"),
[Input(f"{metric_id}-slider", "value") for metric_id in self.metric_ids]
)
def filter_parcoords_figure(metrics_json: dict[str, list], *metric_ranges) -> go.Figure:
"""
Filters parallel coordinates figure based on the metric ranges from the sliders.
"""
return self.plot_parallel_coordinates_line(metrics_json, metric_ranges)
Loading

0 comments on commit 521a8f2

Please sign in to comment.