From c0321f167f72aee43b4b5f7e921501c66d7a16eb Mon Sep 17 00:00:00 2001 From: akore Date: Tue, 14 Nov 2023 09:57:11 -0500 Subject: [PATCH 1/4] add timestamps and add plotly.js render instead of div --- cyclops/report/model_card/fields.py | 4 +- cyclops/report/report.py | 9 +- .../templates/cyclops_generic_template.jinja | 108 +++++++++++++++--- cyclops/report/utils.py | 90 +++++++-------- .../kaggle/heart_failure_prediction.ipynb | 23 +++- tests/cyclops/report/test_utils.py | 21 +++- 6 files changed, 180 insertions(+), 75 deletions(-) diff --git a/cyclops/report/model_card/fields.py b/cyclops/report/model_card/fields.py index 60b973841..d4cdf9c5c 100644 --- a/cyclops/report/model_card/fields.py +++ b/cyclops/report/model_card/fields.py @@ -599,9 +599,9 @@ class MetricCard( description="The trend of the metric over time.", ) - plot: Optional[GraphicsCollection] = Field( + timestamps: Optional[List[StrictStr]] = Field( None, - description="A plot of the performance over time.", + description="Timestamps for each point in the history.", ) diff --git a/cyclops/report/report.py b/cyclops/report/report.py index 80c3b7af1..3f8fa3ff2 100644 --- a/cyclops/report/report.py +++ b/cyclops/report/report.py @@ -45,11 +45,12 @@ _raise_if_not_dict_with_str_keys, create_metric_cards, empty, + get_histories, get_names, get_passed, - get_plots, get_slices, get_thresholds, + get_timestamps, get_trends, regex_replace, regex_search, @@ -1089,8 +1090,10 @@ def export( # write to file if synthetic_timestamp is not None: today = synthetic_timestamp + today_now = synthetic_timestamp else: today = dt_date.today().strftime("%Y-%m-%d") + today_now = dt_datetime.now().strftime("%Y-%m-%d %H:%M:%S") current_report_metrics: List[List[PerformanceMetric]] = [] sweep_metrics(self._model_card, current_report_metrics) @@ -1121,6 +1124,7 @@ def export( # compare tests metrics, tooltips, slices, values, metric_cards = create_metric_cards( current_report_metrics_set, + today_now, latest_report_metric_cards_set, ) self._log_metric_card_collection( @@ -1138,11 +1142,12 @@ def export( "sweep_tests": sweep_tests, "sweep_graphics": sweep_graphics, "get_slices": get_slices, - "get_plots": get_plots, "get_thresholds": get_thresholds, "get_trends": get_trends, "get_passed": get_passed, "get_names": get_names, + "get_histories": get_histories, + "get_timestamps": get_timestamps, } template.globals.update(func_dict) diff --git a/cyclops/report/templates/cyclops_generic_template.jinja b/cyclops/report/templates/cyclops_generic_template.jinja index 6b8ef4455..2569a71b3 100644 --- a/cyclops/report/templates/cyclops_generic_template.jinja +++ b/cyclops/report/templates/cyclops_generic_template.jinja @@ -1,5 +1,13 @@ +{# Get indices of all metric cards for 'overall' slice #} +{% set overall_indices = [] %} +{% for metric_card in model_card.overview.metric_cards.collection%} + {% if metric_card.slice == 'overall' %} + {% set _ = overall_indices.append(loop.index-1) %} + {% endif %} +{% endfor %} + {% macro render_if_exist_list(values) %}
{% if values.__class__.__name__ == "User"%} @@ -124,7 +132,7 @@
{% endmacro %} -{% macro render_metric_card(card)%} +{% macro render_metric_card(card, idx)%}
@@ -156,9 +164,8 @@ {{card.threshold}}
minimum
threshold
- {% if card.plot %} - {{ render_graphic(card.plot.collection[0], class="") }} - {% endif %} +
+
@@ -170,7 +177,7 @@

A quick glance of your most important metrics.

{% for metric_card in comp.metric_cards.collection%} {% if metric_card.slice == 'overall' %} - {{ render_metric_card(metric_card) }} + {{ render_metric_card(metric_card, loop.index-1) }} {% endif %} {% endfor %} @@ -274,18 +281,23 @@ selection.sort(); var slices = JSON.parse({{ get_slices(model_card)|safe|tojson }}); - var plots = JSON.parse({{ get_plots(model_card)|safe|tojson }}); + var histories = JSON.parse({{ get_histories(model_card)|safe|tojson }}); var thresholds = JSON.parse({{ get_thresholds(model_card)|safe|tojson }}); var trends = JSON.parse({{ get_trends(model_card)|safe|tojson }}); var passed = JSON.parse({{ get_passed(model_card)|safe|tojson }}); var names = JSON.parse({{ get_names(model_card)|safe|tojson }}); + var timestamps = JSON.parse({{ get_timestamps(model_card)|safe|tojson }}); // get idx of slices where all elements match var idx = Object.keys(slices).find(key => JSON.stringify(slices[key].sort()) === JSON.stringify(selection)); - var plot_data = []; - for (let i = 0; i < plots[idx].length; i++) { - plot_data.push(parseFloat(plots[idx][i])); + var history_data = []; + for (let i = 0; i < histories[idx].length; i++) { + history_data.push(parseFloat(histories[idx][i])); + } + var timestamp_data = []; + for (let i = 0; i < timestamps[idx].length; i++) { + timestamp_data.push(timestamps[idx][i]); } threshold = parseFloat(thresholds[idx]); trend = trends[idx]; @@ -309,24 +321,25 @@ } // create title for plot: Current {metric name} is trending {trend_keyword} and is {passed_keyword} the threshold. var plot_title = "Current " + name + " is trending " + trend_keyword + " and is " + passed_keyword + " the threshold."; - var trace = { // range of x is the length of the list of floats - x: Array.from({length: plot_data.length}, (_, i) => i), - y: plot_data, + x: timestamp_data, + y: history_data, mode: 'lines+markers', type: 'scatter', marker: {color: 'rgb(31,111,235)'}, line: {color: 'rgb(31,111,235)'}, + name: '', }; var threshold_trace = { - x: Array.from({length: plot_data.length}, (_, i) => i), - y: Array.from({length: plot_data.length}, (_, i) => threshold), + x: timestamp_data, + y: Array.from({length: history_data.length}, (_, i) => threshold), mode: 'lines', type: 'scatter', marker: {color: 'rgb(0,0,0)'}, line: {color: 'rgb(0,0,0)', dash: 'dot'}, + name: '', }; var layout = { @@ -925,7 +938,6 @@ return true; } } - function setActiveButton() { const buttons = document.querySelectorAll('#contents li'); const sections = document.querySelectorAll('.card'); @@ -946,7 +958,71 @@ } } } - document.addEventListener('scroll', setActiveButton); setActiveButton(); + + function generate_model_card_plot() { + var model_card_plots = [] + var overall_indices = {{overall_indices}} + var histories = JSON.parse({{ get_histories(model_card)|safe|tojson }}); + var thresholds = JSON.parse({{ get_thresholds(model_card)|safe|tojson }}); + var timestamps = JSON.parse({{ get_timestamps(model_card)|safe|tojson }}); + + for (let i = 0; i < overall_indices.length; i++) { + var idx = overall_indices[i]; + var model_card_plot = "model-card-plot-" + idx; + var threshold = thresholds[idx]; + var history_data = []; + for (let i = 0; i < histories[idx].length; i++) { + history_data.push(parseFloat(histories[idx][i])); + } + var timestamp_data = []; + for (let i = 0; i < timestamps[idx].length; i++) { + timestamp_data.push(timestamps[idx][i]); + } + + var model_card_fig = { + data: [ + { + x: timestamp_data, + y: history_data, + mode: "lines+markers", + marker: { color: "rgb(31,111,235)" }, + line: { color: "rgb(31,111,235)" }, + showlegend: false, + type: "scatter", + name: "" + }, + { + x: timestamp_data, + y: Array(history_data.length).fill(threshold), + mode: "lines", + line: { color: "black", dash: "dot" }, + showlegend: false, + type: "scatter", + name: "" + } + ], + layout: { + paper_bgcolor: "rgba(0,0,0,0)", + plot_bgcolor: "rgba(0,0,0,0)", + xaxis: { + zeroline: false, + showticklabels: false, + showgrid: false + }, + yaxis: { + gridcolor: "#ffffff" + }, + margin: { l: 0, r: 0, t: 0, b: 0 }, + height: 125, + width: 250 + } + }; + if (history.length > 0) { + Plotly.newPlot(model_card_plot, model_card_fig.data, model_card_fig.layout, {displayModeBar: false}); + } + } + } + generate_model_card_plot(); diff --git a/cyclops/report/utils.py b/cyclops/report/utils.py index 7f1174f6a..440282ab0 100644 --- a/cyclops/report/utils.py +++ b/cyclops/report/utils.py @@ -484,24 +484,6 @@ def get_slices(model_card: ModelCard) -> str: return json.dumps(names) -def get_plots(model_card: ModelCard) -> str: - """Get all plots from a model card.""" - plots: Dict[int, Optional[List[str]]] = {} - if ( - (model_card.overview is None) - or (model_card.overview.metric_cards is None) - or (model_card.overview.metric_cards.collection is None) - ): - pass - else: - for itr, metric_card in enumerate(model_card.overview.metric_cards.collection): - if metric_card.plot is not None: - plots[itr] = [str(history) for history in metric_card.history] - else: - plots[itr] = None - return json.dumps(plots) - - def get_thresholds(model_card: ModelCard) -> str: """Get all thresholds from a model card.""" thresholds: Dict[int, Optional[str]] = {} @@ -513,10 +495,7 @@ def get_thresholds(model_card: ModelCard) -> str: pass else: for itr, metric_card in enumerate(model_card.overview.metric_cards.collection): - if metric_card.plot is not None: - thresholds[itr] = str(metric_card.threshold) - else: - thresholds[itr] = None + thresholds[itr] = str(metric_card.threshold) return json.dumps(thresholds) @@ -531,10 +510,7 @@ def get_trends(model_card: ModelCard) -> str: pass else: for itr, metric_card in enumerate(model_card.overview.metric_cards.collection): - if metric_card.plot is not None: - trends[itr] = metric_card.trend - else: - trends[itr] = None + trends[itr] = metric_card.trend return json.dumps(trends) @@ -549,10 +525,7 @@ def get_passed(model_card: ModelCard) -> str: pass else: for itr, metric_card in enumerate(model_card.overview.metric_cards.collection): - if metric_card.plot is not None: - passed[itr] = metric_card.passed - else: - passed[itr] = None + passed[itr] = metric_card.passed return json.dumps(passed) @@ -571,8 +544,39 @@ def get_names(model_card: ModelCard) -> str: return json.dumps(names) +def get_histories(model_card: ModelCard) -> str: + """Get all plots from a model card.""" + plots: Dict[int, Optional[List[str]]] = {} + if ( + (model_card.overview is None) + or (model_card.overview.metric_cards is None) + or (model_card.overview.metric_cards.collection is None) + ): + pass + else: + for itr, metric_card in enumerate(model_card.overview.metric_cards.collection): + plots[itr] = [str(history) for history in metric_card.history] + return json.dumps(plots) + + +def get_timestamps(model_card: ModelCard) -> str: + """Get all timestamps from a model card.""" + timestamps = {} + if ( + (model_card.overview is None) + or (model_card.overview.metric_cards is None) + or (model_card.overview.metric_cards.collection is None) + ): + pass + else: + for itr, metric_card in enumerate(model_card.overview.metric_cards.collection): + timestamps[itr] = metric_card.timestamps + return json.dumps(timestamps) + + def create_metric_cards( # noqa: PLR0912 PLR0915 current_metrics: List[PerformanceMetric], + timestamp: str, last_metric_cards: Optional[List[MetricCard]] = None, ) -> Tuple[ List[str], @@ -676,20 +680,14 @@ def create_metric_cards( # noqa: PLR0912 PLR0915 isinstance(metric["current_metric"].value, float) ): history.append(metric["current_metric"].value) - if ( - isinstance(metric["current_metric"], PerformanceMetric) - and (metric["current_metric"].tests is not None) - and (isinstance(metric["current_metric"].tests[0], Test)) - and (metric["current_metric"].tests[0].threshold is not None) - ): - plot = create_metric_card_plot( - history, - metric["current_metric"].tests[0].threshold, - ) - (m, b) = np.polyfit(range(len(history)), history, deg=1) - if m >= 0.03: + + timestamps = metric["last_metric_card"].timestamps + if timestamps is not None: + timestamps.append(timestamp) + (m, _) = np.polyfit(range(len(history)), history, deg=1) + if m >= 0.01: trend = "positive" - elif m <= -0.03: + elif m <= -0.01: trend = "negative" else: trend = "neutral" @@ -730,9 +728,7 @@ def create_metric_cards( # noqa: PLR0912 PLR0915 trend=trend if isinstance(metric["current_metric"], PerformanceMetric) else None, - plot=plot - if isinstance(metric["current_metric"], PerformanceMetric) - else None, + timestamps=timestamps, ), ) else: @@ -792,7 +788,7 @@ def create_metric_cards( # noqa: PLR0912 PLR0915 else 0, ], trend="neutral", - plot=None, + timestamps=[timestamp], ), ) metrics = list(dict.fromkeys(metrics)) diff --git a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb index 40aad9a51..b1339a327 100644 --- a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb +++ b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb @@ -1172,9 +1172,20 @@ }, "outputs": [], "source": [ - "report_path = report.export(output_filename=\"heart_failure_report_periodic.html\")\n", + "synthetic_timestamps = [\n", + " \"2021-09-01\",\n", + " \"2021-10-01\",\n", + " \"2021-11-01\",\n", + " \"2021-12-01\",\n", + " \"2022-01-01\",\n", + "]\n", + "report._model_card.overview = None\n", + "report_path = report.export(\n", + " output_filename=\"heart_failure_report_periodic.html\",\n", + " synthetic_timestamp=synthetic_timestamps[0],\n", + ")\n", "shutil.copy(f\"{report_path}\", \".\")\n", - "for _ in range(5):\n", + "for i in range(4):\n", " report._model_card.overview = None\n", " report._model_card.quantitative_analysis = None\n", " results_flat = flatten_results_dict(\n", @@ -1201,8 +1212,12 @@ " pass_fail_thresholds=0.7,\n", " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", " )\n", - " report_path = report.export(output_filename=\"heart_failure_report_periodic.html\")\n", - " shutil.copy(f\"{report_path}\", \".\")" + " report_path = report.export(\n", + " output_filename=\"heart_failure_report_periodic.html\",\n", + " synthetic_timestamp=synthetic_timestamps[i + 1],\n", + " )\n", + " shutil.copy(f\"{report_path}\", \".\")\n", + "shutil.rmtree(\"./cyclops_reports\")" ] }, { diff --git a/tests/cyclops/report/test_utils.py b/tests/cyclops/report/test_utils.py index 93b2440af..a8120c7ea 100644 --- a/tests/cyclops/report/test_utils.py +++ b/tests/cyclops/report/test_utils.py @@ -26,12 +26,13 @@ extract_performance_metrics, filter_results, flatten_results_dict, + get_histories, get_metrics_trends, get_names, get_passed, - get_plots, get_slices, get_thresholds, + get_timestamps, get_trends, sweep_graphics, sweep_metric_cards, @@ -334,9 +335,17 @@ def test_get_slices(model_card): assert len(slices_dict.values()) == 2 -def test_get_plots(model_card): +def test_get_timestamps(model_card): + """Test get_timestamps function.""" + timestamps = get_timestamps(model_card) + # read timestamps from json to dict + timestamps_dict = json.loads(timestamps) + assert len(timestamps_dict.values()) == 2 + + +def test_get_histories(model_card): """Test get_plots function.""" - plots = get_plots(model_card) + plots = get_histories(model_card) # read plots from json to dict plots_dict = json.loads(plots) assert len(plots_dict.values()) == 2 @@ -376,9 +385,13 @@ def test_get_names(model_card): def test_create_metric_cards(model_card): """Test create_metric_cards function.""" + timestamp = "2021-01-01" current_metrics = [] sweep_metrics(model_card, metrics=current_metrics) - metric_cards = create_metric_cards(current_metrics=current_metrics[0])[-1] + metric_cards = create_metric_cards( + current_metrics=current_metrics[0], + timestamp=timestamp, + )[-1] assert len(metric_cards) == 2 From 7e0725f2aa5a6631c8d17affc80425ad4a64c752 Mon Sep 17 00:00:00 2001 From: akore Date: Mon, 20 Nov 2023 09:48:10 -0500 Subject: [PATCH 2/4] add multi-plot selection and gray out missing slices --- .../templates/cyclops_generic_template.jinja | 790 ++++++++++++++---- cyclops/report/utils.py | 25 +- .../kaggle/heart_failure_prediction.ipynb | 60 ++ tests/cyclops/report/test_utils.py | 3 + 4 files changed, 728 insertions(+), 150 deletions(-) diff --git a/cyclops/report/templates/cyclops_generic_template.jinja b/cyclops/report/templates/cyclops_generic_template.jinja index 2569a71b3..a3761ce09 100644 --- a/cyclops/report/templates/cyclops_generic_template.jinja +++ b/cyclops/report/templates/cyclops_generic_template.jinja @@ -184,13 +184,23 @@ {% endmacro %} {% macro render_perf_over_time(name, comp)%} -
+

How is your model doing over time?


See how your model is performing over several metrics and subgroups over time.

+ {#
#} +
+

Multi-plot Selection:

+
+ + + + +
+

Metrics

-
+
{% if comp.metric_cards.metrics[0]|regex_search("\((.*?)\)")|length != 0 %} {% set acronym = comp.metric_cards.metrics[0]|regex_search("\((.*?)\)") %} @@ -237,7 +247,7 @@ {% for slice, values in comp.metric_cards.slices|zip(comp.metric_cards.values) %}

{{slice|regex_replace('(? -
+
{% for value in values %} @@ -249,149 +259,7 @@ {% endfor %}
-
- - -
+
{% endmacro %} @@ -1025,4 +893,634 @@ } } generate_model_card_plot(); + + const plot = document.getElementById('plot'); + const inputs_all = document.querySelectorAll('#slice-selection input[type="radio"]'); + const plot_selection = document.querySelectorAll('#plot-selection input[type="radio"]'); + var selections = [null, null, null, null, null, null, null, null, null, null]; + var plot_colors = [ + "rgb(0, 115, 228)", + "rgb(31, 119, 180)", + "rgb(255, 127, 14)", + "rgb(44, 160, 44)", + "rgb(214, 39, 40)", + "rgb(148, 103, 189)", + "rgb(140, 86, 75)", + "rgb(227, 119, 194)", + "rgb(127, 127, 127)", + "rgb(188, 189, 34)", + ]; + + function updatePlotSelection() { + const inputs = document.querySelectorAll('#slice-selection input[type="radio"]:checked'); + var plot_selection = document.querySelectorAll('#plot-selection input[type="radio"]'); + var plot_selected = document.querySelectorAll('#plot-selection input[type="radio"]:checked')[0]; + // get number from value in plot_selected "Plot 1" -> 1 + var label_selection = document.querySelectorAll('#plot-selection label'); + var label_slice_selection = document.querySelectorAll('#slice-selection label'); + + // if plot_selected is "+" then add new radio button to plot_selection called "Plot N" where last plot is N-1 but keep "+" at end and set new radio button to checked for second last element + if (plot_selected.value === "+") { + // if 10 plots already exist, don't add new plot and gray out "+" + if (plot_selection.length === 13) { + plot_selected.checked = false; + label_selection[-1].style.color = "gray"; + return; + } + var new_plot = document.createElement("input"); + new_plot.type = "radio"; + new_plot.id = "Plot " + (plot_selection.length); + new_plot.name = "plot"; + new_plot.value = "Plot " + (plot_selection.length); + new_plot.checked = true; + var new_label = document.createElement("label"); + new_label.htmlFor = "Plot " + (plot_selection.length); + new_label.innerHTML = "Plot " + (plot_selection.length); + + // Parse plot_color to get r, g, b values + var plot_color = plot_colors[plot_selection.length] + const [r, g, b] = plot_color.match(/\d+/g); + const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; + // set background color of new radio button to plot_color + new_label.style.backgroundColor = rgbaColor; + new_label.style.border = "2px solid " + plot_color; + new_label.style.color = plot_color; + + // insert new radio button and label before "+" radio button and after last radio button + plot_selected.insertAdjacentElement("beforebegin", new_plot); + plot_selected.insertAdjacentElement("beforebegin", new_label); + // Add event listener to new radio button + new_plot.addEventListener('change', updatePlotSelection); + + // set plot_selected to new plot + var plot_selected = new_plot + + for (let i = 0; i < label_selection.length-1; i++) { + plot_selection[i].checked = false; + label_selection[i].style.backgroundColor = "#ffffff"; + label_selection[i].style.border = "2px solid #DADCE0"; + label_selection[i].style.color = "#000000"; + } + + selections[parseInt(plot_selected.value.split(" ")[1]-1)] = selections[parseInt(plot_selected.value.split(" ")[1]-2)] + selection = selections[parseInt(plot_selected.value.split(" ")[1]-1)]; + plot_color = plot_colors[parseInt(plot_selected.value.split(" ")[1])]; + + for (let i = 0; i < selection.length; i++) { + // use selection to set label_slice_selection background color + for (let j = 0; j < inputs_all.length; j++) { + if (inputs_all[j].name === selection[i].split(":")[0]) { + if (inputs_all[j].value == selection[i].split(":")[1]) { + const [r, g, b] = plot_color.match(/\d+/g); + const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; + inputs_all[j].checked = true; + label_slice_selection[j].style.backgroundColor = rgbaColor; + label_slice_selection[j].style.border = "2px solid " + plot_color; + label_slice_selection[j].style.color = plot_color; + } + else { + inputs_all[j].checked = false; + label_slice_selection[j].style.backgroundColor = "#ffffff"; + label_slice_selection[j].style.border = "2px solid #DADCE0"; + label_slice_selection[j].style.color = "#000000"; + } + } + } + } + } else { + for (let i = 0; i < plot_selection.length-1; i++) { + if (plot_selection[i].value !== plot_selected.value) { + plot_selection[i].checked = false; + label_selection[i].style.backgroundColor = "#ffffff"; + label_selection[i].style.border = "2px solid #DADCE0"; + label_selection[i].style.color = "#000000"; + } + else { + var plot_color = plot_colors[i+1] + const [r, g, b] = plot_color.match(/\d+/g); + const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; + plot_selected.checked = true; + label_selection[i].style.backgroundColor = rgbaColor; + label_selection[i].style.border = "2px solid " + plot_color; + label_selection[i].style.color = plot_color; + } + } + selection = selections[parseInt(plot_selected.value.split(" ")[1]-1)]; + plot_color = plot_colors[parseInt(plot_selected.value.split(" ")[1])]; + for (let i = 0; i < selection.length; i++) { + // use selection to set label_slice_selection background color + for (let j = 0; j < inputs_all.length; j++) { + if (inputs_all[j].name === selection[i].split(":")[0]) { + if (inputs_all[j].value == selection[i].split(":")[1]) { + inputs_all[j].checked = true; + const [r, g, b] = plot_color.match(/\d+/g); + const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; + label_slice_selection[j].style.backgroundColor = rgbaColor; + label_slice_selection[j].style.border = "2px solid " + plot_color; + label_slice_selection[j].style.color = plot_color; + } + else { + inputs_all[j].checked = false; + label_slice_selection[j].style.backgroundColor = "#ffffff"; + label_slice_selection[j].style.border = "2px solid #DADCE0"; + label_slice_selection[j].style.color = "#000000"; + } + } + } + } + } + var slices_all = JSON.parse({{ get_slices(model_card)|safe|tojson }}); + var histories_all = JSON.parse({{ get_histories(model_card)|safe|tojson }}); + var thresholds_all = JSON.parse({{ get_thresholds(model_card)|safe|tojson }}); + var trends_all = JSON.parse({{ get_trends(model_card)|safe|tojson }}); + var passed_all = JSON.parse({{ get_passed(model_card)|safe|tojson }}); + var names_all = JSON.parse({{ get_names(model_card)|safe|tojson }}); + var timestamps_all = JSON.parse({{ get_timestamps(model_card)|safe|tojson }}); + + var radioGroups = {}; + var labelGroups = {}; + for (let i = 0; i < inputs_all.length; i++) { + var input = inputs_all[i]; + var label = label_slice_selection[i]; + var groupName = input.name; + if (!radioGroups[groupName]) { + radioGroups[groupName] = []; + labelGroups[groupName] = []; + } + radioGroups[groupName].push(input); + labelGroups[groupName].push(label); + } + + // use radioGroups to loop through selection changing only one element at a time + for (let i = 0; i < selection.length; i++) { + for (let j = 0; j < inputs_all.length; j++) { + if (inputs_all[j].name === selection[i].split(":")[0]) { + radio_group = radioGroups[selection[i].split(":")[0]]; + label_group = labelGroups[selection[i].split(":")[0]]; + for (let k = 0; k < radio_group.length; k++) { + selection_copy = selection.slice(); + selection_copy[i] = selection[i].split(":")[0] + ":" + radio_group[k].value; + // get idx of slices where all elements match + var idx = Object.keys(slices_all).find(key => JSON.stringify(slices_all[key].sort()) === JSON.stringify(selection_copy.sort())); + if (idx === undefined) { + // set radio button to disabled and cursor to not allowed and color to gray if idx is undefined + radio_group[k].disabled = true; + label_group[k].style.cursor = "not-allowed"; + label_group[k].style.color = "gray"; + label_group[k].style.backgroundColor = "rgba(125, 125, 125, 0.2)"; + } + else { + radio_group[k].disabled = false; + label_group[k].style.cursor = "pointer"; + } + } + } + } + } + + traces = []; + var plot_number = parseInt(plot_selected.value.split(" ")[1]-1); + for (let i = 0; i < selections.length; i++) { + if (selections[i] === null) { + continue; + } + selection = selections[i] + + // get idx of slices where all elements match + var idx = Object.keys(slices_all).find(key => JSON.stringify(slices_all[key].sort()) === JSON.stringify(selection)); + var history_data = []; + for (let i = 0; i < histories_all[idx].length; i++) { + history_data.push(parseFloat(histories_all[idx][i])); + } + var timestamp_data = []; + for (let i = 0; i < timestamps_all[idx].length; i++) { + timestamp_data.push(timestamps_all[idx][i]); + } + threshold = parseFloat(thresholds_all[idx]); + trend = trends_all[idx]; + passed = passed_all[idx]; + name = names_all[idx]; + + // if trend is "positive" set keyword to upwards, if trend is "negative" set keyword to downwards, else set keyword to flat + if (trend === "positive") { + var trend_keyword = "upwards"; + } else if (trend === "negative") { + var trend_keyword = "downwards"; + } else { + var trend_keyword = "flat"; + } + + // if passed is true set keyword to Above, if passed is false set keyword to Below + if (passed) { + var passed_keyword = "above"; + } + else { + var passed_keyword = "below"; + } + + // create title for plot: Current {metric name} is trending {trend_keyword} and is {passed_keyword} the threshold. + // get number of nulls in selections, if 9 then plot title, else don't plot title + var nulls = 0; + for (let i = 0; i < selections.length; i++) { + if (selections[i] === null) { + nulls += 1; + } + } + if (nulls === 9) { + var plot_title = "Current " + name + " is trending " + trend_keyword + " and is " + passed_keyword + " the threshold."; + var showlegend = false; + } + else { + var plot_title = ""; + var showlegend = true; + } + name = "" + suffix = " ( " + for (let i = 0; i < selection.length; i++) { + if (selection[i].split(":")[0] === "metric") { + name += selection[i].split(":")[1]; + } + else { + if (selection[i].split(":")[1].includes("overall")) { + continue; + } else { + suffix += selection[i]; + suffix += ", "; + } + } + } + if (suffix === " ( ") { + name += ""; + } + else { + suffix = suffix.slice(0, -2); + name += suffix + " )"; + } + + var trace = { + // range of x is the length of the list of floats + x: timestamp_data, + y: history_data, + mode: 'lines+markers', + type: 'scatter', + marker: {color: plot_colors[i+1]}, + line: {color: plot_colors[i+1]}, + name: name, + }; + traces.push(trace); + } + + if (nulls === 9) { + var threshold_trace = { + x: timestamp_data, + y: Array.from({length: history_data.length}, (_, i) => threshold), + mode: 'lines', + type: 'scatter', + marker: {color: 'rgb(0,0,0)'}, + line: {color: 'rgb(0,0,0)', dash: 'dot'}, + name: '', + }; + traces.push(threshold_trace); + } + var layout = { + title: { + text: plot_title, + font: { + family: 'Arial, Helvetica, sans-serif', + size: 18, + } + }, + paper_bgcolor: 'rgba(0,0,0,0)', + plot_bgcolor: 'rgba(0,0,0,0)', + xaxis: { + zeroline: false, + showticklabels: false, + showgrid: false, + }, + yaxis: { + gridcolor: '#ffffff', + }, + showlegend: showlegend, + margin: { + l: 50, + r: 50, + b: 50, + t: 50, + pad: 4 + }, + // set height and width of plot to extra-wide to fit the plot + height: 500, + width: 900, + } + Plotly.newPlot(plot, traces, layout, {displayModeBar: false}); + } + + + + // Define a function to update the plot based on selected filters + function updatePlot() { + const inputs = document.querySelectorAll('#slice-selection input[type="radio"]:checked'); + var plot_selection = document.querySelectorAll('#plot-selection input[type="radio"]'); + var plot_selected = document.querySelectorAll('#plot-selection input[type="radio"]:checked')[0]; + // get number from value in plot_selected "Plot 1" -> 1 + var label_selection = document.querySelectorAll('#plot-selection label'); + var label_slice_selection = document.querySelectorAll('#slice-selection label'); + + // get all inputs values from div class radio-buttons + // get name of inputs + var inputs_name = []; + var inputs_value = []; + for (let i = 0; i < inputs.length; i++) { + inputs_name.push(inputs[i].name); + inputs_value.push(inputs[i].value); + } + + var plot_number = parseInt(plot_selected.value.split(" ")[1]-1); + var selection = []; + for (let i = 0; i < inputs_value.length; i++) { + // check if *_overall in string, if so don't push + // if (inputs_value[i].includes("overall_")) { + // continue; + // } + selection.push(inputs_name[i] + ":" + inputs_value[i]); + } + selection.sort(); + selections[plot_number] = selection; + + // if plot_selected is "+" then add new radio button to plot_selection called "Plot N" where last plot is N-1 but keep "+" at end and set new radio button to checked for second last element + if (plot_selected.value === "+") { + // if 10 plots already exist, don't add new plot and gray out "+" + if (plot_selection.length === 13) { + plot_selected.checked = false; + label_selection[-1].style.color = "gray"; + return; + } + var new_plot = document.createElement("input"); + new_plot.type = "radio"; + new_plot.id = "Plot " + (plot_selection.length); + new_plot.name = "plot"; + new_plot.value = "Plot " + (plot_selection.length); + new_plot.checked = true; + var new_label = document.createElement("label"); + new_label.htmlFor = "Plot " + (plot_selection.length); + new_label.innerHTML = "Plot " + (plot_selection.length); + + // Parse plot_color to get r, g, b values + var plot_color = plot_colors[plot_selection.length] + const [r, g, b] = plot_color.match(/\d+/g); + const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; + // set background color of new radio button to plot_color + new_label.style.backgroundColor = rgbaColor; + new_label.style.border = "2px solid " + plot_color; + new_label.style.color = plot_color; + + // insert new radio button and label before "+" radio button and after last radio button + plot_selected.insertAdjacentElement("beforebegin", new_plot); + plot_selected.insertAdjacentElement("beforebegin", new_label); + // Add event listener to new radio button + new_plot.addEventListener('change', updatePlot); + + // set plot_selected to new plot + plot_selected = new_plot + + for (let i = 0; i < label_selection.length-1; i++) { + plot_selection[i].checked = false; + label_selection[i].style.backgroundColor = "#ffffff"; + label_selection[i].style.border = "2px solid #DADCE0"; + label_selection[i].style.color = "#000000"; + } + } else { + for (let i = 0; i < plot_selection.length-1; i++) { + if (plot_selection[i].value !== plot_selected.value) { + plot_selection[i].checked = false; + label_selection[i].style.backgroundColor = "#ffffff"; + label_selection[i].style.border = "2px solid #DADCE0"; + label_selection[i].style.color = "#000000"; + } + else { + var plot_color = plot_colors[i+1] + const [r, g, b] = plot_color.match(/\d+/g); + const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; + plot_selected.checked = true; + label_selection[i].style.backgroundColor = rgbaColor; + label_selection[i].style.border = "2px solid " + plot_color; + label_selection[i].style.color = plot_color; + } + } + } + var slices_all = JSON.parse({{ get_slices(model_card)|safe|tojson }}); + var histories_all = JSON.parse({{ get_histories(model_card)|safe|tojson }}); + var thresholds_all = JSON.parse({{ get_thresholds(model_card)|safe|tojson }}); + var trends_all = JSON.parse({{ get_trends(model_card)|safe|tojson }}); + var passed_all = JSON.parse({{ get_passed(model_card)|safe|tojson }}); + var names_all = JSON.parse({{ get_names(model_card)|safe|tojson }}); + var timestamps_all = JSON.parse({{ get_timestamps(model_card)|safe|tojson }}); + + for (let i = 0; i < selection.length; i++) { + // use selection to set label_slice_selection background color + for (let j = 0; j < inputs_all.length; j++) { + if (inputs_all[j].name === selection[i].split(":")[0]) { + if (inputs_all[j].value == selection[i].split(":")[1]) { + inputs_all[j].checked = true; + const [r, g, b] = plot_color.match(/\d+/g); + const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; + label_slice_selection[j].style.backgroundColor = rgbaColor; + label_slice_selection[j].style.border = "2px solid " + plot_color; + label_slice_selection[j].style.color = plot_color; + } + else { + inputs_all[j].checked = false; + label_slice_selection[j].style.backgroundColor = "#ffffff"; + label_slice_selection[j].style.border = "2px solid #DADCE0"; + label_slice_selection[j].style.color = "#000000"; + } + } + } + } + + var radioGroups = {}; + var labelGroups = {}; + for (let i = 0; i < inputs_all.length; i++) { + var input = inputs_all[i]; + var label = label_slice_selection[i]; + var groupName = input.name; + if (!radioGroups[groupName]) { + radioGroups[groupName] = []; + labelGroups[groupName] = []; + } + radioGroups[groupName].push(input); + labelGroups[groupName].push(label); + } + + // use radioGroups to loop through selection changing only one element at a time + for (let i = 0; i < selection.length; i++) { + for (let j = 0; j < inputs_all.length; j++) { + if (inputs_all[j].name === selection[i].split(":")[0]) { + radio_group = radioGroups[selection[i].split(":")[0]]; + label_group = labelGroups[selection[i].split(":")[0]]; + for (let k = 0; k < radio_group.length; k++) { + selection_copy = selection.slice(); + selection_copy[i] = selection[i].split(":")[0] + ":" + radio_group[k].value; + // get idx of slices where all elements match + var idx = Object.keys(slices_all).find(key => JSON.stringify(slices_all[key].sort()) === JSON.stringify(selection_copy.sort())); + if (idx === undefined) { + // set radio button to disabled and cursor to not allowed and color to gray if idx is undefined + radio_group[k].disabled = true; + label_group[k].style.cursor = "not-allowed"; + label_group[k].style.color = "gray"; + label_group[k].style.backgroundColor = "rgba(125, 125, 125, 0.2)"; + } + else { + radio_group[k].disabled = false; + label_group[k].style.cursor = "pointer"; + } + } + } + } + } + + traces = []; + for (let i = 0; i < selections.length; i++) { + if (selections[i] === null) { + continue; + } + selection = selections[i] + // get idx of slices where all elements match + var idx = Object.keys(slices_all).find(key => JSON.stringify(slices_all[key].sort()) === JSON.stringify(selection)); + var history_data = []; + for (let i = 0; i < histories_all[idx].length; i++) { + history_data.push(parseFloat(histories_all[idx][i])); + } + var timestamp_data = []; + for (let i = 0; i < timestamps_all[idx].length; i++) { + timestamp_data.push(timestamps_all[idx][i]); + } + threshold = parseFloat(thresholds_all[idx]); + trend = trends_all[idx]; + passed = passed_all[idx]; + name = names_all[idx]; + + // if trend is "positive" set keyword to upwards, if trend is "negative" set keyword to downwards, else set keyword to flat + if (trend === "positive") { + var trend_keyword = "upwards"; + } else if (trend === "negative") { + var trend_keyword = "downwards"; + } else { + var trend_keyword = "flat"; + } + + // if passed is true set keyword to Above, if passed is false set keyword to Below + if (passed) { + var passed_keyword = "above"; + } + else { + var passed_keyword = "below"; + } + + // create title for plot: Current {metric name} is trending {trend_keyword} and is {passed_keyword} the threshold. + // get number of nulls in selections, if 9 then plot title, else don't plot title + var nulls = 0; + for (let i = 0; i < selections.length; i++) { + if (selections[i] === null) { + nulls += 1; + } + } + if (nulls === 9) { + var plot_title = "Current " + name + " is trending " + trend_keyword + " and is " + passed_keyword + " the threshold."; + var showlegend = false; + } + else { + var plot_title = ""; + var showlegend = true; + } + name = "" + suffix = " ( " + for (let i = 0; i < selection.length; i++) { + if (selection[i].split(":")[0] === "metric") { + name += selection[i].split(":")[1]; + } + else { + if (selection[i].split(":")[1].includes("overall")) { + continue; + } else { + suffix += selection[i]; + suffix += ", "; + } + } + } + if (suffix === " ( ") { + name += ""; + } + else { + suffix = suffix.slice(0, -2); + name += suffix + " )"; + } + var trace = { + // range of x is the length of the list of floats + x: timestamp_data, + y: history_data, + mode: 'lines+markers', + type: 'scatter', + marker: {color: plot_colors[i+1]}, + line: {color: plot_colors[i+1]}, + name: name, + //name: selection.toString(), + }; + traces.push(trace); + } + + if (nulls === 9) { + var threshold_trace = { + x: timestamp_data, + y: Array.from({length: history_data.length}, (_, i) => threshold), + mode: 'lines', + type: 'scatter', + marker: {color: 'rgb(0,0,0)'}, + line: {color: 'rgb(0,0,0)', dash: 'dot'}, + name: '', + }; + traces.push(threshold_trace); + } + var layout = { + title: { + text: plot_title, + font: { + family: 'Arial, Helvetica, sans-serif', + size: 18, + } + }, + paper_bgcolor: 'rgba(0,0,0,0)', + plot_bgcolor: 'rgba(0,0,0,0)', + xaxis: { + zeroline: false, + showticklabels: false, + showgrid: false, + }, + yaxis: { + gridcolor: '#ffffff', + }, + showlegend: showlegend, + margin: { + l: 50, + r: 50, + b: 50, + t: 50, + pad: 4 + }, + // set height and width of plot to extra-wide to fit the plot + height: 500, + width: 900, + } + Plotly.newPlot(plot, traces, layout, {displayModeBar: false}); + } + // Add event listeners to radio buttons + for (let input of inputs_all) { + input.addEventListener('change', updatePlot); + } + for (let selection of plot_selection) { + selection.addEventListener('change', updatePlotSelection); + } + // Initial update when the page loads + updatePlot(); + diff --git a/cyclops/report/utils.py b/cyclops/report/utils.py index 440282ab0..cf5a05605 100644 --- a/cyclops/report/utils.py +++ b/cyclops/report/utils.py @@ -468,18 +468,35 @@ def get_slices(model_card: ModelCard) -> str: if ( (model_card.overview is None) or (model_card.overview.metric_cards is None) + or (model_card.overview.metric_cards.slices is None) or (model_card.overview.metric_cards.collection is None) ): pass else: + all_slices = model_card.overview.metric_cards.slices for itr, metric_card in enumerate(model_card.overview.metric_cards.collection): name = ( ["metric:" + metric_card.name] if metric_card.name else ["metric:none"] ) card_slice = metric_card.slice - card_slice_list = card_slice.split("&") if card_slice else "overall" - name.extend(card_slice_list) - name = [e for e in name if e != "overall"] + if card_slice is not None: + if card_slice == "overall": + card_slice_list = [ + f"{slices}:overall_{slices}" for slices in all_slices + ] + else: + card_slice_list = card_slice.split("&") + card_slice_list_split = [ + card_slice.split(":")[0] for card_slice in card_slice_list + ] + + for slices in all_slices: + card_slice_list_split = [ + card_slice.split(":")[0] for card_slice in card_slice_list + ] + if slices not in card_slice_list_split: + card_slice_list.append(f"{slices}:overall_{slices}") + name.extend(card_slice_list) names[itr] = name return json.dumps(names) @@ -590,7 +607,7 @@ def create_metric_cards( # noqa: PLR0912 PLR0915 for current_metric in current_metrics: if current_metric.slice is not None: for slice_val in current_metric.slice.split("&"): - if slice_val not in slices_values: + if slice_val not in slices_values and slice_val != "overall": slices_values.append(slice_val) slices = [ slice_val.split(":")[0] diff --git a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb index b1339a327..94c3cbd1a 100644 --- a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb +++ b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb @@ -847,6 +847,23 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results_female, _ = mortality_task.evaluate(\n", + " dataset[\"test\"],\n", + " create_metric(metric_name=\"accuracy\", task=\"binary\"),\n", + " model_names=model_name,\n", + " transforms=preprocessor,\n", + " prediction_column_prefix=\"preds\",\n", + " slice_spec=SliceSpec([{\"Sex\": {\"value\": \"F\"}}], include_overall=False),\n", + " batch_size=32,\n", + ")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -870,6 +887,30 @@ " model_name=model_name,\n", ")\n", "\n", + "results_female_flat = flatten_results_dict(\n", + " results=results_female,\n", + " model_name=model_name,\n", + ")\n", + "\n", + "for name, metric in results_female_flat.items():\n", + " split, name = name.split(\"/\") # noqa: PLW2901\n", + " descriptions = {\n", + " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", + " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n", + " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", + " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", + " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", + " }\n", + " report.log_quantitative_analysis(\n", + " \"performance\",\n", + " name=name,\n", + " value=metric,\n", + " description=descriptions[name],\n", + " metric_slice=split,\n", + " pass_fail_thresholds=0.7,\n", + " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", + " )\n", + "\n", "for name, metric in results_flat.items():\n", " split, name = name.split(\"/\") # noqa: PLW2901\n", " descriptions = {\n", @@ -1194,6 +1235,25 @@ " model_name=model_name,\n", " )\n", "\n", + " for name, metric in results_female_flat.items():\n", + " split, name = name.split(\"/\") # noqa: PLW2901\n", + " descriptions = {\n", + " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", + " \"BinaryRecall\": \"The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.\",\n", + " \"BinaryAccuracy\": \"The proportion of all instances that are correctly predicted.\",\n", + " \"BinaryAUROC\": \"The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.\",\n", + " \"BinaryF1Score\": \"The harmonic mean of precision and recall.\",\n", + " }\n", + " report.log_quantitative_analysis(\n", + " \"performance\",\n", + " name=name,\n", + " value=np.clip(metric + np.random.normal(0, 0.1), 0, 1),\n", + " description=descriptions[name],\n", + " metric_slice=split,\n", + " pass_fail_thresholds=0.7,\n", + " pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),\n", + " )\n", + "\n", " for name, metric in results_flat.items():\n", " split, name = name.split(\"/\") # noqa: PLW2901\n", " descriptions = {\n", diff --git a/tests/cyclops/report/test_utils.py b/tests/cyclops/report/test_utils.py index a8120c7ea..6de671ac8 100644 --- a/tests/cyclops/report/test_utils.py +++ b/tests/cyclops/report/test_utils.py @@ -252,7 +252,10 @@ def model_card(): """Create a test input for model card.""" model_card = ModelCard() model_card.overview = Overview( + slices=["overall"], metric_cards=MetricCardCollection( + metrics=["BinaryAccuracy", "BinaryPrecision"], + slices=["overall"], collection=[ MetricCard( name="Accuracy", From 38c8c3d5c3f7038b1a44d91aff74f84bf42dd4f4 Mon Sep 17 00:00:00 2001 From: akore Date: Mon, 20 Nov 2023 11:49:23 -0500 Subject: [PATCH 3/4] fix heart_failure_prediction notebook --- .../kaggle/heart_failure_prediction.ipynb | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb index cb9051107..d415fbe56 100644 --- a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb +++ b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb @@ -570,7 +570,7 @@ "source": [ "## Task Creation\n", "\n", - "We use Cyclops tasks to define our model's task (in this case, MortalityPrediction), train the model, make predictions, and evaluate performance. Cyclops task classes encapsulate the entire ML pipeline into a single, cohesive structure, making the process smooth and easy to manage." + "We use Cyclops tasks to define our model's task (in this case, heart failure prediction), train the model, make predictions, and evaluate performance. Cyclops task classes encapsulate the entire ML pipeline into a single, cohesive structure, making the process smooth and easy to manage." ] }, { @@ -581,7 +581,7 @@ }, "outputs": [], "source": [ - "mortality_task = BinaryTabularClassificationTask(\n", + "heart_failure_prediction_task = BinaryTabularClassificationTask(\n", " {model_name: model},\n", " task_features=features_list,\n", " task_target=\"outcome\",\n", @@ -596,7 +596,7 @@ }, "outputs": [], "source": [ - "mortality_task.list_models()" + "heart_failure_prediction_task.list_models()" ] }, { @@ -626,7 +626,7 @@ " \"method\": \"grid\",\n", "}\n", "\n", - "mortality_task.train(\n", + "heart_failure_prediction_task.train(\n", " dataset[\"train\"],\n", " model_name=model_name,\n", " transforms=preprocessor,\n", @@ -640,7 +640,7 @@ "metadata": {}, "outputs": [], "source": [ - "model_params = mortality_task.list_models_params()[model_name]\n", + "model_params = heart_failure_prediction_task.list_models_params()[model_name]\n", "print(model_params)" ] }, @@ -679,7 +679,7 @@ }, "outputs": [], "source": [ - "y_pred = mortality_task.predict(\n", + "y_pred = heart_failure_prediction_task.predict(\n", " dataset[\"test\"],\n", " model_name=model_name,\n", " transforms=preprocessor,\n", @@ -836,9 +836,9 @@ }, "outputs": [], "source": [ - "results, dataset_with_preds = mortality_task.evaluate(\n", - " dataset[\"test\"],\n", - " metric_collection,\n", + "results, dataset_with_preds = heart_failure_prediction_task.evaluate(\n", + " dataset=dataset[\"test\"],\n", + " metrics=metric_collection,\n", " model_names=model_name,\n", " transforms=preprocessor,\n", " prediction_column_prefix=\"preds\",\n", @@ -855,9 +855,11 @@ "metadata": {}, "outputs": [], "source": [ - "results_female, _ = mortality_task.evaluate(\n", - " dataset[\"test\"],\n", - " create_metric(metric_name=\"accuracy\", task=\"binary\"),\n", + "results_female, _ = heart_failure_prediction_task.evaluate(\n", + " dataset=dataset[\"test\"],\n", + " metrics=MetricCollection(\n", + " {\"BinaryAccuracy\": create_metric(metric_name=\"accuracy\", task=\"binary\")},\n", + " ),\n", " model_names=model_name,\n", " transforms=preprocessor,\n", " prediction_column_prefix=\"preds\",\n", From dd0cc4fd26a95ebebbddf1e591cb27cc021cbdac Mon Sep 17 00:00:00 2001 From: akore Date: Tue, 21 Nov 2023 09:19:49 -0500 Subject: [PATCH 4/4] add ability to delete plots --- .../templates/cyclops_generic_template.jinja | 130 +++++++++++++++--- 1 file changed, 114 insertions(+), 16 deletions(-) diff --git a/cyclops/report/templates/cyclops_generic_template.jinja b/cyclops/report/templates/cyclops_generic_template.jinja index a3761ce09..73fe67d46 100644 --- a/cyclops/report/templates/cyclops_generic_template.jinja +++ b/cyclops/report/templates/cyclops_generic_template.jinja @@ -184,7 +184,7 @@ {% endmacro %} {% macro render_perf_over_time(name, comp)%} -
+

How is your model doing over time?


See how your model is performing over several metrics and subgroups over time.

{#
#} @@ -408,6 +408,7 @@ .card { display: flex; flex-wrap: wrap; + flex-basis: 100%; justify-content: left; padding: 1em; border: 1px solid #DADCE0; @@ -773,6 +774,23 @@ } + .radio-buttons #button { + padding-right: 5px; + margin-left: 5px; + margin-right: 0px; + font-size:18px; + font-weight: bold; + cursor: pointer; + color: black; + background-color: #ffffff; + border: 2px solid #DADCE0; + } + + .radio-buttons #button:hover { + color: #0073e4; + } + + @@ -897,7 +915,7 @@ const plot = document.getElementById('plot'); const inputs_all = document.querySelectorAll('#slice-selection input[type="radio"]'); const plot_selection = document.querySelectorAll('#plot-selection input[type="radio"]'); - var selections = [null, null, null, null, null, null, null, null, null, null]; + var selections = [null, null, null, null, null, null, null, null, null, null, null]; var plot_colors = [ "rgb(0, 115, 228)", "rgb(31, 119, 180)", @@ -909,33 +927,97 @@ "rgb(227, 119, 194)", "rgb(127, 127, 127)", "rgb(188, 189, 34)", + "rgb(23, 190, 207)" ]; + function deletePlotSelection(plot_number) { + var plot_selection = document.querySelectorAll('#plot-selection input[type="radio"]'); + var label_selection = document.querySelectorAll('#plot-selection label'); + var label_slice_selection = document.querySelectorAll('#slice-selection label'); + var button_plot_selection = document.querySelectorAll('#plot-selection button'); + + // set last plot to checked + // get plot_selection with name "Plot N" where N is plot_number + for (let i = 0; i < plot_selection.length; i++) { + var plot_name = "Plot " + (plot_number+1) + if (plot_selection[i].value === plot_name) { + plot_number = i; + } + } + plot_selection[plot_number].checked = false; + plot_selection[plot_number-1].checked = true; + + // delete plot_selected and label + plot_selection[plot_number].remove(); + label_selection[plot_number].remove(); + + selections[plot_number] = null; + + // set selection to last plot + selection = selections[plot_number-1]; + plot_color = plot_colors[plot_number-1]; + + // set current plot selection color to plot_color + const [r, g, b] = plot_color.match(/\d+/g); + const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; + plot_selection[plot_number-1].style.backgroundColor = rgbaColor; + plot_selection[plot_number-1].style.border = "2px solid " + plot_color; + plot_selection[plot_number-1].style.color = plot_color; + + // make visibility of delete button from last plot visible + if (button_plot_selection.length >= 2) { + button_plot_selection[button_plot_selection.length-2].style.visibility = "visible"; + } + + for (let i = 0; i < selection.length; i++) { + // use selection to set label_slice_selection background color + for (let j = 0; j < inputs_all.length; j++) { + if (inputs_all[j].name === selection[i].split(":")[0]) { + if (inputs_all[j].value == selection[i].split(":")[1]) { + inputs_all[j].checked = true; + label_slice_selection[j].style.backgroundColor = rgbaColor; + label_slice_selection[j].style.border = "2px solid " + plot_color; + label_slice_selection[j].style.color = plot_color; + } + } + } + } + updatePlot(); + } + function updatePlotSelection() { const inputs = document.querySelectorAll('#slice-selection input[type="radio"]:checked'); var plot_selection = document.querySelectorAll('#plot-selection input[type="radio"]'); var plot_selected = document.querySelectorAll('#plot-selection input[type="radio"]:checked')[0]; // get number from value in plot_selected "Plot 1" -> 1 + var plot_number = parseInt(plot_selected.value.split(" ")[1]); var label_selection = document.querySelectorAll('#plot-selection label'); var label_slice_selection = document.querySelectorAll('#slice-selection label'); + var button_plot_selection = document.querySelectorAll('#plot-selection button'); // if plot_selected is "+" then add new radio button to plot_selection called "Plot N" where last plot is N-1 but keep "+" at end and set new radio button to checked for second last element if (plot_selected.value === "+") { // if 10 plots already exist, don't add new plot and gray out "+" - if (plot_selection.length === 13) { + if (plot_selection.length === 11) { plot_selected.checked = false; - label_selection[-1].style.color = "gray"; + label_selection[label_selection.length-1].style.color = "gray"; return; } + // plot_name should be name of last plot + 1 + if (plot_selection.length === 2) { + var plot_name = "Plot 2" + } else { + var plot_name = "Plot " + (parseInt(plot_selection[plot_selection.length - 2].value.split(" ")[1]) + 1); + } var new_plot = document.createElement("input"); new_plot.type = "radio"; - new_plot.id = "Plot " + (plot_selection.length); + new_plot.id = plot_name; new_plot.name = "plot"; - new_plot.value = "Plot " + (plot_selection.length); + new_plot.value = plot_name; new_plot.checked = true; var new_label = document.createElement("label"); - new_label.htmlFor = "Plot " + (plot_selection.length); - new_label.innerHTML = "Plot " + (plot_selection.length); + new_label.htmlFor = plot_name; + new_label.innerHTML = plot_name; // Parse plot_color to get r, g, b values var plot_color = plot_colors[plot_selection.length] @@ -946,9 +1028,28 @@ new_label.style.border = "2px solid " + plot_color; new_label.style.color = plot_color; + // add button to delete plot + var delete_button = document.createElement("button"); + delete_button.id = "button"; + delete_button.innerHTML = "×"; + delete_button.style.backgroundColor = "transparent"; + delete_button.style.border = "none"; + new_label.style.padding = "1.5px 0px"; + new_label.style.paddingLeft = "10px"; + + new_label.appendChild(delete_button) + + // make delete button from last plot invisible if not Plot 1 + if (plot_selection.length > 2) { + button_plot_selection[button_plot_selection.length-1].style.visibility = "hidden"; + } + // add on_click event to delete button and send plot number to deletePlotSelection + delete_button.onclick = function() {deletePlotSelection(plot_number)}; + // insert new radio button and label before "+" radio button and after last radio button plot_selected.insertAdjacentElement("beforebegin", new_plot); plot_selected.insertAdjacentElement("beforebegin", new_label); + // Add event listener to new radio button new_plot.addEventListener('change', updatePlotSelection); @@ -1120,13 +1221,14 @@ // create title for plot: Current {metric name} is trending {trend_keyword} and is {passed_keyword} the threshold. // get number of nulls in selections, if 9 then plot title, else don't plot title + console.log(selections) var nulls = 0; for (let i = 0; i < selections.length; i++) { if (selections[i] === null) { nulls += 1; } } - if (nulls === 9) { + if (nulls === 10) { var plot_title = "Current " + name + " is trending " + trend_keyword + " and is " + passed_keyword + " the threshold."; var showlegend = false; } @@ -1170,7 +1272,7 @@ traces.push(trace); } - if (nulls === 9) { + if (nulls === 10) { var threshold_trace = { x: timestamp_data, y: Array.from({length: history_data.length}, (_, i) => threshold), @@ -1238,10 +1340,6 @@ var plot_number = parseInt(plot_selected.value.split(" ")[1]-1); var selection = []; for (let i = 0; i < inputs_value.length; i++) { - // check if *_overall in string, if so don't push - // if (inputs_value[i].includes("overall_")) { - // continue; - // } selection.push(inputs_name[i] + ":" + inputs_value[i]); } selection.sort(); @@ -1425,7 +1523,7 @@ nulls += 1; } } - if (nulls === 9) { + if (nulls === 10) { var plot_title = "Current " + name + " is trending " + trend_keyword + " and is " + passed_keyword + " the threshold."; var showlegend = false; } @@ -1469,7 +1567,7 @@ traces.push(trace); } - if (nulls === 9) { + if (nulls === 10) { var threshold_trace = { x: timestamp_data, y: Array.from({length: history_data.length}, (_, i) => threshold),